diff --git a/VeOmni b/VeOmni index 600fe6d..8ca09d7 160000 --- a/VeOmni +++ b/VeOmni @@ -1 +1 @@ -Subproject commit 600fe6d7442392fd3ddefad4b6c8d3c0002fed1c +Subproject commit 8ca09d7c87f06ee7c0f69b0ca0c9e9a6b37f2280 diff --git a/configs/sft/llada2_flash_bd_sft.yaml b/configs/sft/llada2_flash_bd_sft.yaml index 603d4bd..11232b2 100644 --- a/configs/sft/llada2_flash_bd_sft.yaml +++ b/configs/sft/llada2_flash_bd_sft.yaml @@ -2,58 +2,75 @@ model: config_path: ./configs/model_configs/llada2_flash model_path: ./LLaDA2.0-flash-preview-moe-merge tokenizer_path: ./LLaDA2.0-flash-preview-moe-merge - attn_implementation: sdpa - moe_implementation: fused + ops_implementation: + attn_implementation: sdpa + moe_implementation: fused_triton + cross_entropy_loss_implementation: eager + rms_norm_implementation: eager + swiglu_mlp_implementation: eager + rotary_pos_emb_implementation: eager + rotary_pos_emb_vision_implementation: eager + load_balancing_loss_implementation: eager + rms_norm_gated_implementation: eager + causal_conv1d_implementation: eager + chunk_gated_delta_rule_implementation: eager data: train_path: ./gsm8k_datasets/gsm8k_train.jsonl data_type: conversation datasets_type: mapping - dataloader_type: native max_seq_len: 2048 text_keys: messages noise_range_low: 0.3 noise_range_high: 0.8 - num_workers: 16 + dataloader: + type: native + num_workers: 16 + drop_last: true + pin_memory: true train: - output_dir: ./llada2_flash_bd_sft_outputs - data_parallel_mode: fsdp2 - tensor_parallel_size: 1 - ulysses_parallel_size: 1 - expert_parallel_size: 1 + dyn_bsz: false global_batch_size: 16 micro_batch_size: 1 num_train_epochs: 1 - rmpad: false - rmpad_with_pos_ids: false bsz_warmup_ratio: 0.007 - dyn_bsz_margin: 0 - dyn_bsz_buffer_size: 200 - optimizer: adamw - beta1: 0.9 - beta2: 0.999 - lr: 1.0e-5 - lr_warmup_ratio: 0.03 - lr_decay_style: cosine - lr_decay_ratio: 1.0 - weight_decay: 0.1 - max_grad_norm: 1.0 - enable_mixed_precision: true - enable_gradient_checkpointing: true - enable_full_shard: true - enable_fsdp_offload: true - enable_activation_offload: false init_device: meta broadcast_model_weights_from_rank0: true enable_full_determinism: false empty_cache_steps: 500 - ckpt_manager: dcp - load_checkpoint_path: "" - save_epochs: 1 - save_hf_weights: true + beta1: 0.9 + beta2: 0.999 block_diffusion_mode: true block_size: 32 same_token_labels: true - use_wandb: false # or you can set `wandb_project` and `wandb_name` to trace your training - log_steps: 1 + optimizer: + type: adamw + lr: 1.0e-5 + lr_warmup_ratio: 0.03 + lr_decay_style: cosine + lr_decay_ratio: 1.0 + weight_decay: 0.1 + max_grad_norm: 1.0 + accelerator: + tp_size: 1 + ep_size: 1 + pp_size: 1 + ulysses_size: 1 + cp_size: 1 + fsdp_config: + fsdp_mode: fsdp2 + offload: true + mixed_precision: + enable: true + gradient_checkpointing: + enable: true + enable_reentrant: false + checkpoint: + output_dir: ./llada2_flash_bd_sft_outputs + manager: dcp + load_path: null + save_epochs: 1 + save_hf_weights: true + wandb: + enable: false diff --git a/configs/sft/llada2_flash_bd_sft_npu.yaml b/configs/sft/llada2_flash_bd_sft_npu.yaml new file mode 100644 index 0000000..615182e --- /dev/null +++ b/configs/sft/llada2_flash_bd_sft_npu.yaml @@ -0,0 +1,76 @@ +model: + config_path: ./configs/model_configs/llada2_flash + model_path: ./LLaDA2.0-flash-preview-moe-merge + tokenizer_path: ./LLaDA2.0-flash-preview-moe-merge + ops_implementation: + attn_implementation: sdpa + moe_implementation: fused_npu + cross_entropy_loss_implementation: npu + rms_norm_implementation: npu + swiglu_mlp_implementation: eager + rotary_pos_emb_implementation: npu + rotary_pos_emb_vision_implementation: eager + load_balancing_loss_implementation: eager + rms_norm_gated_implementation: eager + causal_conv1d_implementation: eager + chunk_gated_delta_rule_implementation: eager + +data: + train_path: ./gsm8k_datasets/gsm8k_train.jsonl + data_type: conversation + datasets_type: mapping + max_seq_len: 2048 + text_keys: messages + noise_range_low: 0.3 + noise_range_high: 0.8 + dataloader: + type: native + num_workers: 16 + drop_last: true + pin_memory: true + +train: + dyn_bsz: false + global_batch_size: 16 + micro_batch_size: 1 + num_train_epochs: 1 + bsz_warmup_ratio: 0.007 + init_device: meta + broadcast_model_weights_from_rank0: true + enable_full_determinism: false + empty_cache_steps: 500 + beta1: 0.9 + beta2: 0.999 + block_diffusion_mode: true + block_size: 32 + same_token_labels: true + optimizer: + type: adamw + lr: 1.0e-5 + lr_warmup_ratio: 0.03 + lr_decay_style: cosine + lr_decay_ratio: 1.0 + weight_decay: 0.1 + max_grad_norm: 1.0 + accelerator: + tp_size: 1 + ep_size: 1 + pp_size: 1 + ulysses_size: 1 + cp_size: 1 + fsdp_config: + fsdp_mode: fsdp2 + offload: true + mixed_precision: + enable: true + gradient_checkpointing: + enable: true + enable_reentrant: false + checkpoint: + output_dir: ./llada2_flash_bd_sft_npu_outputs + manager: dcp + load_path: null + save_epochs: 1 + save_hf_weights: true + wandb: + enable: false diff --git a/configs/sft/llada2_mini_bd_sft.yaml b/configs/sft/llada2_mini_bd_sft.yaml index 5e188e0..8d126e8 100644 --- a/configs/sft/llada2_mini_bd_sft.yaml +++ b/configs/sft/llada2_mini_bd_sft.yaml @@ -2,58 +2,75 @@ model: config_path: ./configs/model_configs/llada2_mini model_path: ./LLaDA2.0-mini-preview-moe-merge tokenizer_path: ./LLaDA2.0-mini-preview-moe-merge - attn_implementation: sdpa - moe_implementation: fused + ops_implementation: + attn_implementation: sdpa + moe_implementation: fused_triton + cross_entropy_loss_implementation: eager + rms_norm_implementation: eager + swiglu_mlp_implementation: eager + rotary_pos_emb_implementation: eager + rotary_pos_emb_vision_implementation: eager + load_balancing_loss_implementation: eager + rms_norm_gated_implementation: eager + causal_conv1d_implementation: eager + chunk_gated_delta_rule_implementation: eager data: train_path: ./gsm8k_datasets/gsm8k_train.jsonl data_type: conversation datasets_type: mapping - dataloader_type: native max_seq_len: 2048 text_keys: messages noise_range_low: 0.3 noise_range_high: 0.8 - num_workers: 16 + dataloader: + type: native + num_workers: 16 + drop_last: true + pin_memory: true train: - output_dir: ./llada2_mini_bd_sft_outputs - data_parallel_mode: fsdp2 - tensor_parallel_size: 1 - ulysses_parallel_size: 1 - expert_parallel_size: 1 + dyn_bsz: false global_batch_size: 8 micro_batch_size: 1 num_train_epochs: 1 - rmpad: false - rmpad_with_pos_ids: false bsz_warmup_ratio: 0.007 - dyn_bsz_margin: 0 - dyn_bsz_buffer_size: 200 - optimizer: adamw - beta1: 0.9 - beta2: 0.999 - lr: 1.0e-5 - lr_warmup_ratio: 0.03 - lr_decay_style: cosine - lr_decay_ratio: 1.0 - weight_decay: 0.1 - max_grad_norm: 1.0 - enable_mixed_precision: true - enable_gradient_checkpointing: true - enable_full_shard: true - enable_fsdp_offload: true - enable_activation_offload: false init_device: meta broadcast_model_weights_from_rank0: true enable_full_determinism: false empty_cache_steps: 500 - ckpt_manager: dcp - load_checkpoint_path: "" - save_epochs: 1 - save_hf_weights: true + beta1: 0.9 + beta2: 0.999 block_diffusion_mode: true block_size: 32 same_token_labels: true - use_wandb: false # or you can set `wandb_project` and `wandb_name` to trace your training - log_steps: 1 + optimizer: + type: adamw + lr: 1.0e-5 + lr_warmup_ratio: 0.03 + lr_decay_style: cosine + lr_decay_ratio: 1.0 + weight_decay: 0.1 + max_grad_norm: 1.0 + accelerator: + tp_size: 1 + ep_size: 1 + pp_size: 1 + ulysses_size: 1 + cp_size: 1 + fsdp_config: + fsdp_mode: fsdp2 + offload: true + mixed_precision: + enable: true + gradient_checkpointing: + enable: true + enable_reentrant: false + checkpoint: + output_dir: ./llada2_mini_bd_sft_outputs + manager: dcp + load_path: null + save_epochs: 1 + save_hf_weights: true + wandb: + enable: false diff --git a/configs/sft/llada2_mini_bd_sft_npu.yaml b/configs/sft/llada2_mini_bd_sft_npu.yaml new file mode 100644 index 0000000..35faf5c --- /dev/null +++ b/configs/sft/llada2_mini_bd_sft_npu.yaml @@ -0,0 +1,76 @@ +model: + config_path: ./configs/model_configs/llada2_mini + model_path: ./LLaDA2.0-mini-preview-moe-merge + tokenizer_path: ./LLaDA2.0-mini-preview-moe-merge + ops_implementation: + attn_implementation: sdpa + moe_implementation: fused_npu + cross_entropy_loss_implementation: npu + rms_norm_implementation: npu + swiglu_mlp_implementation: eager + rotary_pos_emb_implementation: npu + rotary_pos_emb_vision_implementation: eager + load_balancing_loss_implementation: eager + rms_norm_gated_implementation: eager + causal_conv1d_implementation: eager + chunk_gated_delta_rule_implementation: eager + +data: + train_path: ./gsm8k_datasets/gsm8k_train.jsonl + data_type: conversation + datasets_type: mapping + max_seq_len: 2048 + text_keys: messages + noise_range_low: 0.3 + noise_range_high: 0.8 + dataloader: + type: native + num_workers: 16 + drop_last: true + pin_memory: true + +train: + dyn_bsz: false + global_batch_size: 8 + micro_batch_size: 1 + num_train_epochs: 1 + bsz_warmup_ratio: 0.007 + init_device: meta + broadcast_model_weights_from_rank0: true + enable_full_determinism: false + empty_cache_steps: 500 + beta1: 0.9 + beta2: 0.999 + block_diffusion_mode: true + block_size: 32 + same_token_labels: true + optimizer: + type: adamw + lr: 1.0e-5 + lr_warmup_ratio: 0.03 + lr_decay_style: cosine + lr_decay_ratio: 1.0 + weight_decay: 0.1 + max_grad_norm: 1.0 + accelerator: + tp_size: 1 + ep_size: 1 + pp_size: 1 + ulysses_size: 1 + cp_size: 1 + fsdp_config: + fsdp_mode: fsdp2 + offload: true + mixed_precision: + enable: true + gradient_checkpointing: + enable: true + enable_reentrant: false + checkpoint: + output_dir: ./llada2_mini_bd_sft_npu_outputs + manager: dcp + load_path: null + save_epochs: 1 + save_hf_weights: true + wandb: + enable: false diff --git a/models/llada2_moe/__init__.py b/models/llada2_moe/__init__.py index 1f22972..a83d5e4 100644 --- a/models/llada2_moe/__init__.py +++ b/models/llada2_moe/__init__.py @@ -1,3 +1,28 @@ -from .modeling_llada2_moe import LLaDA2MoeModelLM +from veomni.models.loader import MODEL_CONFIG_REGISTRY, MODELING_REGISTRY -ModelClass = LLaDA2MoeModelLM \ No newline at end of file +from .configuration_llada2_moe import LLaDA2MoeConfig +from .modeling_llada2_moe import LLaDA2MoeModel, LLaDA2MoeModelLM, LLaDA2MoePreTrainedModel + + +@MODEL_CONFIG_REGISTRY.register("llada2_moe_veomni") +def register_llada2_moe_config(): + return LLaDA2MoeConfig + + +@MODELING_REGISTRY.register("llada2_moe_veomni") +def register_llada2_moe_modeling(architecture: str): + if architecture and ("ForCausalLM" in architecture or "ModelLM" in architecture): + return LLaDA2MoeModelLM + if architecture and "Model" in architecture: + return LLaDA2MoeModel + return LLaDA2MoeModelLM + +ModelClass = LLaDA2MoeModelLM + +__all__ = [ + "LLaDA2MoeConfig", + "LLaDA2MoeModel", + "LLaDA2MoeModelLM", + "LLaDA2MoePreTrainedModel", + "ModelClass", +] diff --git a/models/llada2_moe/modeling_llada2_moe.py b/models/llada2_moe/modeling_llada2_moe.py index 820dd42..c0c0f22 100644 --- a/models/llada2_moe/modeling_llada2_moe.py +++ b/models/llada2_moe/modeling_llada2_moe.py @@ -43,21 +43,35 @@ ) from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update from transformers.modeling_utils import PreTrainedModel, ALL_ATTENTION_FUNCTIONS -from transformers.pytorch_utils import ALL_LAYERNORM_LAYERS, is_torch_greater_or_equal_than_1_13 +from transformers.pytorch_utils import ALL_LAYERNORM_LAYERS +try: + from transformers.pytorch_utils import is_torch_greater_or_equal_than_1_13 +except ImportError: + def is_torch_greater_or_equal_than_1_13(): + return True from transformers.utils import ( add_start_docstrings, add_start_docstrings_to_model_forward, replace_return_docstrings, ) -from transformers.utils.import_utils import is_torch_fx_available +try: + from transformers.utils.import_utils import is_torch_fx_available +except ImportError: + def is_torch_fx_available(): + return True from .configuration_llada2_moe import LLaDA2MoeConfig from transformers.generation.utils import GenerationMixin -from veomni.ops import causallm_loss_function, fused_moe_forward +from veomni.ops import fused_moe_forward from veomni.distributed.parallel_state import get_parallel_state -from veomni.utils.import_utils import is_liger_kernel_available +from veomni.utils.import_utils import is_liger_kernel_available, is_torch_npu_available from veomni.utils import logging -if is_liger_kernel_available(): + +def _liger_kernel_enabled(): + return is_liger_kernel_available() and not is_torch_npu_available() + + +if _liger_kernel_enabled(): from liger_kernel.ops.swiglu import LigerSiLUMulFunction from liger_kernel.transformers.rms_norm import LigerRMSNorm from liger_kernel.transformers.rope import liger_rotary_pos_emb @@ -74,6 +88,40 @@ logger = logging.get_logger(__name__) _CONFIG_FOR_DOC = "LLaDA2MoeConfig" +_LLADA2_MOE_OPS_PATCHED_IMPL = None + + +def _apply_llada2_moe_ops_config(): + global _LLADA2_MOE_OPS_PATCHED_IMPL + + try: + from veomni.ops.config.singleton import get_ops_config + from veomni.ops.kernels.moe import apply_veomni_fused_moe_patch + except Exception: + return + + ops_config = get_ops_config() + if ops_config is None: + return + + moe_impl = getattr(ops_config, "moe_implementation", "eager") + if moe_impl == "eager" or moe_impl == _LLADA2_MOE_OPS_PATCHED_IMPL: + return + + apply_veomni_fused_moe_patch(fused_moe_kernel=moe_impl.removeprefix("fused_")) + _LLADA2_MOE_OPS_PATCHED_IMPL = moe_impl + + +def _get_llada2_moe_implementation(): + try: + from veomni.ops.config.singleton import get_ops_config + except Exception: + return "eager" + + ops_config = get_ops_config() + if ops_config is None: + return "eager" + return getattr(ops_config, "moe_implementation", "eager") def _get_unpad_data(attention_mask): @@ -108,6 +156,24 @@ def forward(self, hidden_states): ALL_LAYERNORM_LAYERS.append(LLaDA2MoeRMSNorm) +def _llada2_default_rope_init(config: LLaDA2MoeConfig, device=None): + head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads) + partial_rotary_factor = getattr(config, "partial_rotary_factor", 1.0) + dim = int(head_dim * partial_rotary_factor) + inv_freq = 1.0 / ( + config.rope_theta ** (torch.arange(0, dim, 2, dtype=torch.int64, device=device).float() / dim) + ) + return inv_freq, 1.0 + + +def _get_rope_init_fn(rope_type: str): + if rope_type in ROPE_INIT_FUNCTIONS: + return ROPE_INIT_FUNCTIONS[rope_type] + if rope_type == "default": + return _llada2_default_rope_init + raise KeyError(rope_type) + + class LLaDA2MoeRotaryEmbedding(nn.Module): def __init__(self, config: LLaDA2MoeConfig, device=None): super().__init__() @@ -120,7 +186,7 @@ def __init__(self, config: LLaDA2MoeConfig, device=None): self.original_max_seq_len = config.max_position_embeddings self.config = config - self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type] + self.rope_init_fn = _get_rope_init_fn(self.rope_type) inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device) self.register_buffer("inv_freq", inv_freq, persistent=False) @@ -202,7 +268,7 @@ def __init__(self, config: LLaDA2MoeConfig, intermediate_size: int): self.act_fn = ACT2FN[config.hidden_act] def forward(self, x): - if is_liger_kernel_available(): + if _liger_kernel_enabled(): return self.down_proj(LigerSiLUMulFunction.apply(self.gate_proj(x), self.up_proj(x))) else: return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) @@ -275,6 +341,7 @@ def forward(self, hidden_states): class LLaDA2MoeExperts(nn.Module): def __init__(self, config): super().__init__() + _apply_llada2_moe_ops_config() self.num_experts = config.num_experts self.hidden_dim = config.hidden_size self.intermediate_size = config.moe_intermediate_size @@ -306,9 +373,8 @@ def forward(self, hidden_states, expert_idx=None, routing_weights=None, selected ) out = fused_moe_forward( - module=self, num_experts=self.num_experts, - routing_weights=routing_weights, + routing_weights=routing_weights.to(hidden_states.dtype), selected_experts=selected_experts, hidden_states=hidden_states, fc1_1_weight=self.gate_proj, @@ -343,7 +409,7 @@ def __init__(self, config: LLaDA2MoeConfig): self._setup_experts() self.gate = LLaDA2MoeGate(config) - if config.num_shared_experts is not None: + if config.num_shared_experts: self.shared_experts = LLaDA2MoeMLP( config=config, intermediate_size=config.moe_intermediate_size * config.num_shared_experts ) @@ -364,13 +430,24 @@ def _fuse_moe_forward(self, hidden_states): bsz, seq_len, h = hidden_states.shape topk_idx, topk_weight, router_logits = self.gate(hidden_states) hidden_states = hidden_states.view(-1, hidden_states.shape[-1]) - y = self.experts( - hidden_states, routing_weights=topk_weight, selected_experts=topk_idx - ).reshape(bsz, seq_len, h) - if self.config.num_shared_experts is not None: + if _get_llada2_moe_implementation() == "eager": + y = self._stacked_experts_eager_forward(hidden_states, topk_idx, topk_weight).reshape(bsz, seq_len, h) + else: + y = self.experts( + hidden_states, routing_weights=topk_weight, selected_experts=topk_idx + ).reshape(bsz, seq_len, h) + if self.config.num_shared_experts: y = y + self.shared_experts(identity) return y, (router_logits.view(bsz, seq_len, -1), topk_idx.view(bsz, seq_len, -1)) + def _stacked_experts_eager_forward(self, hidden_states, topk_idx, topk_weight): + flat_topk_idx = topk_idx.view(-1) + hidden_states = hidden_states.repeat_interleave(self.num_experts_per_tok, dim=0) + y = torch.empty_like(hidden_states) + for i in range(self.config.num_experts): + y[flat_topk_idx == i] = self.experts(hidden_states[flat_topk_idx == i], expert_idx=i) + return (y.view(*topk_weight.shape, -1) * topk_weight.unsqueeze(-1)).sum(dim=1).to(hidden_states.dtype) + def _forward(self, hidden_states): identity = hidden_states bsz, seq_len, h = hidden_states.shape @@ -386,7 +463,7 @@ def _forward(self, hidden_states): y = y.to(hidden_states.dtype).view(bsz, seq_len, h) else: y = self.moe_infer(hidden_states, topk_idx, topk_weight).view(bsz, seq_len, h) - if self.config.num_shared_experts is not None: + if self.config.num_shared_experts: y = y + self.shared_experts(identity) return y, (router_logits.view(bsz, seq_len, -1), topk_idx.view(bsz, seq_len, -1)) @@ -1566,7 +1643,7 @@ def apply_rotary_pos_emb_llada2_moe(q, k, cos, sin, position_ids, unsqueeze_dim= return q_embed, k_embed -if is_liger_kernel_available(): +if _liger_kernel_enabled(): apply_rotary_pos_emb = apply_rotary_pos_emb_llada2_moe LLaDA2MoeRMSNorm = LigerRMSNorm logger.info_rank0("Apply liger kernel to LLaDA2Moe") diff --git a/models/llada2_moe/parallel_plan.py b/models/llada2_moe/parallel_plan.py index fd31fd3..391467a 100644 --- a/models/llada2_moe/parallel_plan.py +++ b/models/llada2_moe/parallel_plan.py @@ -10,6 +10,6 @@ def get_parallel_plan(): "model.layers.*.mlp.experts.down_proj": Shard(0), } parallel_plan = ParallelPlan( - ep_plan=ep_plan, + extra_parallel_plan={"ep": ep_plan}, ) return parallel_plan diff --git a/tasks/dataset/__init__.py b/tasks/dataset/__init__.py index 8c8e893..a866cc9 100644 --- a/tasks/dataset/__init__.py +++ b/tasks/dataset/__init__.py @@ -3,4 +3,4 @@ __all__ = [ "build_local_dataset" -] \ No newline at end of file +] diff --git a/tasks/train_llada2_bd.py b/tasks/train_llada2_bd.py index ec9ccdb..f00a321 100644 --- a/tasks/train_llada2_bd.py +++ b/tasks/train_llada2_bd.py @@ -1,572 +1,8 @@ -import json -import os -import time -from dataclasses import asdict, dataclass, field -from functools import partial -from typing import Any, Dict, List, Literal, Tuple, Optional - -import torch -import torch.distributed as dist -import wandb -from tqdm import trange - -from veomni.checkpoint import build_checkpointer, ckpt_to_state_dict -from veomni.data import ( - build_dataloader, - build_iterative_dataset, - build_mapping_dataset, -) -from veomni.distributed.offloading import build_activation_offloading_context -from veomni.distributed.parallel_state import get_parallel_state, init_parallel_state -from veomni.distributed.torch_parallelize import build_parallelize_model -from veomni.models import build_foundation_model, build_tokenizer, save_model_assets, save_model_weights -from veomni.optim import build_lr_scheduler, build_optimizer -from veomni.utils import helper -from veomni.utils.arguments import DataArguments, ModelArguments, TrainingArguments, parse_args, save_args -from veomni.utils.device import ( - get_device_type, - get_nccl_backend, - get_torch_device, - synchronize, -) -from veomni.utils.dist_utils import all_reduce -from veomni.models.registry import ModelRegistry -ModelRegistry.register_modeling_path("models.llada2_moe") -from dataset.data_transform import process_mdm_tokenized_example, process_mdm_sft_example -from dataset import build_local_dataset - - -logger = helper.create_logger(__name__) - -@dataclass -class LLaDA2ModelArguments(ModelArguments): - attn_implementation: Optional[Literal["eager", "sdpa", "flex_attention"]] = field( - default="sdpa", - metadata={"help": "Attention implementation to use."}, - ) - - -@dataclass -class LLaDA2DataArguments(DataArguments): - data_type: Literal["conversation", "tokenid"] = field( - default="conversation", - metadata={"help": "Type of the training data."}, - ) - datasets_type: Literal["mapping", "local"] = field( - default="mapping", - metadata={"help": "Type of the datasets."}, - ) - text_keys: str = field( - default="messages", - metadata={"help": "Key to get text from the training data."}, - ) - noise_range_low: float = field( - default=0.3, - metadata={"help": "Noise level for random flip input_ids to mask_ids"} - ) - noise_range_high: float = field( - default=0.8, - metadata={"help": "Noise level for random flip input_ids to mask_ids"} - ) - - def __post_init__(self): - super().__post_init__() - if self.noise_range_low > self.noise_range_high: - raise ValueError( - f"noise_range_low ({self.noise_range_low}) " - f"cannot be greater than noise_range_high ({self.noise_range_high})." - ) - - if not (0.0 <= self.noise_range_low <= 1.0): - raise ValueError( - f"noise_range_low must be between 0.0 and 1.0, but got {self.noise_range_low}." - ) - - if not (0.0 <= self.noise_range_high <= 1.0): - raise ValueError( - f"noise_range_high must be between 0.0 and 1.0, but got {self.noise_range_high}." - ) - - -@dataclass -class LLaDA2TrainingArguments(TrainingArguments): - beta1: float = field( - default=0.9, - metadata={"help": "AdamW optimizer beta1."}, - ) - beta2: float = field( - default=0.999, - metadata={"help": "AdamW optimizer beta2"}, - ) - block_diffusion_mode: bool = field( - default=False, - metadata={"help": "If train MDM in block_diffusion mode. True: use block_diffusion, False: full_attention"} - ) - block_size: int = field( - default=32, - metadata={"help": "The block size for block diffusion block size"} - ) - same_token_labels: bool = field( - default=False, - metadata={"help": "If use same token location labels. True: no shift, False: use next-token prediction shift."} - ) - - -@dataclass -class Arguments: - model: "LLaDA2ModelArguments" = field(default_factory=LLaDA2ModelArguments) - data: "LLaDA2DataArguments" = field(default_factory=LLaDA2DataArguments) - train: "LLaDA2TrainingArguments" = field(default_factory=LLaDA2TrainingArguments) - - -def block_diffusion_mask(b, h, q_idx, kv_idx, block_size=None, n=None): - """ - Constructs the specialized block diffusion attention mask for training - composed of three masks: - - **Block Diagonal Mask (M_BD)**: Self-attention within noised blocks - - **Offset Block Causal Mask (M_OBC)**: Cross-attention for conditional context - - **Block Causal Mask (M_BC)**: Attention to update x0 - - Args: - b, h: Batch and head indices (ignored for mask logic). - q_idx, kv_idx: Query and Key indices. - seq_len: Total sequence length. - block_size: Defines the block structure. - - Returns: - A boolean attention mask. - """ - - # Indicate whether token belongs to xt or x0 - x0_flag_q = (q_idx >= n) - x0_flag_kv = (kv_idx >= n) - - # Compute block indices - block_q = torch.where(x0_flag_q == 1, - (q_idx - n) // block_size, - q_idx // block_size) - block_kv = torch.where(x0_flag_kv == 1, - (kv_idx - n) // block_size, - kv_idx // block_size) - - # **1. Block Diagonal Mask (M_BD) ** - block_diagonal = (block_q == block_kv) & (x0_flag_q == x0_flag_kv) - - # **2. Offset Block-Causal Mask (M_OBC) ** - offset_block_causal = ( - (block_q > block_kv) - & (x0_flag_kv == 1) - & (x0_flag_q == 0) - ) - - # **3. Block-Causal Mask (M_BC) ** - block_causal = (block_q >= block_kv) & (x0_flag_kv == 1) & (x0_flag_q == 1) - - # **4. Combine Masks ** - return block_diagonal | offset_block_causal | block_causal +from train_llada2_common import LLaDA2Arguments, run_llada2_training def main(): - dist.init_process_group(backend=get_nccl_backend()) - args = parse_args(Arguments) - logger.info(f"Process rank: {args.train.global_rank}, world size: {args.train.world_size}") - logger.info_rank0(json.dumps(asdict(args), indent=2)) - get_torch_device().set_device(f"{get_device_type()}:{args.train.local_rank}") - helper.set_seed(args.train.seed, args.train.enable_full_determinism) - if args.train.local_rank == 0: - helper.enable_third_party_logging() - - if args.train.global_rank == 0: - save_args(args, args.train.output_dir) - - Checkpointer = build_checkpointer(dist_backend=args.train.data_parallel_mode, ckpt_manager=args.train.ckpt_manager) - - init_parallel_state( - dp_size=args.train.data_parallel_size, - dp_replicate_size=args.train.data_parallel_replicate_size, - dp_shard_size=args.train.data_parallel_shard_size, - tp_size=args.train.tensor_parallel_size, - ep_size=args.train.expert_parallel_size, - pp_size=args.train.pipeline_parallel_size, - cp_size=args.train.context_parallel_size, - ulysses_size=args.train.ulysses_parallel_size, - dp_mode=args.train.data_parallel_mode, - ) - - logger.info_rank0("Prepare data") - tokenizer = build_tokenizer(args.model.tokenizer_path) - if args.data.data_type == "conversation": - if not tokenizer.chat_template: - raise ValueError(f"No chat template found in the tokenizer.") - - transform = partial( - process_mdm_sft_example, - tokenizer=tokenizer, - max_seq_len=args.data.max_seq_len, - text_keys=args.data.text_keys, - noise_range=(args.data.noise_range_low, args.data.noise_range_high), - mask_token_id=156895, - ) - elif args.data.data_type == "tokenid": - transform = partial( - process_mdm_tokenized_example, - max_seq_len=args.data.max_seq_len, - text_keys=args.data.text_keys, - noise_range=(args.data.noise_range_low, args.data.noise_range_high), - mask_token_id=156895, - ) - else: - raise NotImplementedError(f"Unsupported data type: {args.data.data_type}.") - - if args.data.dataloader_type == "native": - if args.data.datasets_type == "iterable": - logger.info_rank0("Start building iterative dataset") - train_dataset = build_iterative_dataset(args.data.train_path, transform=transform, seed=args.train.seed) - elif args.data.datasets_type == "mapping": - logger.info_rank0("Start building mapping dataset") - train_dataset = build_mapping_dataset(args.data.train_path, transform=transform) - elif args.data.datasets_type == "local": - logger.info_rank0("Start building local dataset") - train_dataset = build_local_dataset(args.data.train_path, transform=transform) - - dataset_length = None if not hasattr(train_dataset, "__len__") else len(train_dataset) - if args.data.datasets_type == "mapping" or args.data.datasets_type == "local": - dataset_length = dataset_length / args.train.data_parallel_size - args.train.compute_train_steps(args.data.max_seq_len, args.data.train_size, dataset_length) - - train_dataloader = build_dataloader( - dataset=train_dataset, - micro_batch_size=args.train.micro_batch_size, - global_batch_size=args.train.global_batch_size, - dataloader_batch_size=args.train.dataloader_batch_size, - seed=args.train.seed, - max_seq_len=args.data.max_seq_len, - train_steps=args.train.train_steps, - rmpad=args.train.rmpad, - rmpad_with_pos_ids=args.train.rmpad_with_pos_ids, - bsz_warmup_ratio=args.train.bsz_warmup_ratio, - bsz_warmup_init_mbtoken=args.train.bsz_warmup_init_mbtoken, - dyn_bsz_margin=args.train.dyn_bsz_margin, - dyn_bsz_buffer_size=args.train.dyn_bsz_buffer_size, - num_workers=args.data.num_workers, - drop_last=args.data.drop_last, - pin_memory=args.data.pin_memory, - prefetch_factor=args.data.prefetch_factor, - ) - else: - raise NotImplementedError(f"Unsupported dataloader type: {args.data.dataloader_type}.") - - logger.info_rank0("Prepare model") - model = build_foundation_model( - config_path=args.model.config_path, - weights_path=args.model.model_path, - torch_dtype="float32" if args.train.enable_mixed_precision else "bfloat16", - attn_implementation=args.model.attn_implementation, - moe_implementation=args.model.moe_implementation, - init_device=args.train.init_device, - force_use_huggingface=args.model.force_use_huggingface, - ) - model_config = model.config - helper.print_device_mem_info("VRAM usage after building model") - - get_optimizer_pre_hook = getattr(model, "get_optimizer_pre_hook", None) - model = build_parallelize_model( - model, - init_device=args.train.init_device, - weights_path=args.model.model_path, - enable_full_shard=args.train.enable_full_shard, - enable_mixed_precision=args.train.enable_mixed_precision, - enable_gradient_checkpointing=args.train.enable_gradient_checkpointing, - enable_fsdp_offload=args.train.enable_fsdp_offload, - basic_modules=model._no_split_modules + args.model.basic_modules, - enable_reentrant=args.train.enable_reentrant, - enable_forward_prefetch=args.train.enable_forward_prefetch, - broadcast_model_weights_from_rank0=args.train.broadcast_model_weights_from_rank0 - ) - - optimizer = build_optimizer( - model, - lr=args.train.lr, - betas=(args.train.beta1, args.train.beta2), - weight_decay=args.train.weight_decay, - fused=True, - optimizer_type=args.train.optimizer, - ) - - if get_optimizer_pre_hook is not None: - optimizer_pre_hook = get_optimizer_pre_hook(model, model_config, args.train.data_parallel_mode) - optimizer.register_step_pre_hook(optimizer_pre_hook) - - lr_scheduler = build_lr_scheduler( - optimizer, - train_steps=args.train.train_steps * args.train.num_train_epochs, - lr=args.train.lr, - lr_min=args.train.lr_min, - lr_decay_style=args.train.lr_decay_style, - lr_decay_ratio=args.train.lr_decay_ratio, - lr_warmup_ratio=args.train.lr_warmup_ratio, - lr_start=args.train.lr_start, - ) - - if args.train.global_rank == 0: - if args.train.use_wandb: - wandb.init( - project=args.train.wandb_project, - name=args.train.wandb_name, - config={**vars(args.model), **vars(args.data), **vars(args.train)}, # flatten dict - ) - - # save model_assets before training - model_assets = [model_config, tokenizer] - save_model_assets(args.train.model_assets_dir, model_assets) - - if args.train.profile_this_rank: - profiler = helper.create_profiler( - start_step=args.train.profile_start_step, - end_step=args.train.profile_end_step, - trace_dir=args.train.profile_trace_dir, - record_shapes=args.train.profile_record_shapes, - profile_memory=args.train.profile_profile_memory, - with_stack=args.train.profile_with_stack, - global_rank=args.train.global_rank, - ) - profiler.start() - - start_epoch, start_step, global_step = 0, 0, 0 - save_checkpoint_path = None - environ_meter = helper.EnvironMeter( - config=model_config, - global_batch_size=args.train.global_batch_size, - rmpad=args.train.rmpad, - rmpad_with_pos_ids=args.train.rmpad_with_pos_ids, - empty_cache_steps=args.train.empty_cache_steps, - enable_multisource=args.data.enable_multisource, - dataloader=train_dataloader, - data_path=args.data.train_path, - ) - - if args.train.load_checkpoint_path: - state = {"model": model, "optimizer": optimizer, "extra_state": {}} # cannot be None - Checkpointer.load(args.train.load_checkpoint_path, state) - global_step = state["extra_state"]["global_step"] - start_epoch = global_step // args.train.train_steps - start_step = global_step % args.train.train_steps - lr_scheduler.load_state_dict(state["extra_state"]["lr_scheduler"]) - train_dataloader.load_state_dict(state["extra_state"]["train_dataloader"]) - environ_meter.load_state_dict(state["extra_state"]["environ_meter"]) - torch.set_rng_state(state["extra_state"]["torch_rng_state"]) - if start_step == 0: # resume at the end of epoch - iter(train_dataloader) # clear resume state and prefetch data - - dist.barrier() - logger.info_rank0(f"Load distributed checkpoint from {args.train.load_checkpoint_path} successfully!") - - # Build block diffusion attention mask - if args.train.block_diffusion_mode: - bd_attn_full_len = args.data.max_seq_len * 2 - block_size = args.train.block_size - # NOTE: Boolean dtype block diffusion attention mask - block_diffusion_attn_mask_flag = block_diffusion_mask( - b=None, h=None, - q_idx=torch.arange(bd_attn_full_len)[:, None], - kv_idx=torch.arange(bd_attn_full_len)[None, :], - block_size=block_size, - n=args.data.max_seq_len - ).unsqueeze(0).unsqueeze(0) - - block_diffusion_attn_mask_prototype = torch.zeros_like( - block_diffusion_attn_mask_flag, - dtype=torch.float32 if args.train.enable_mixed_precision else torch.bfloat16 - ) - block_diffusion_attn_mask_prototype.masked_fill_(block_diffusion_attn_mask_flag.logical_not(), float("-inf")) - - helper.empty_cache() - model_fwd_context, model_bwd_context = build_activation_offloading_context( - args.train.enable_activation_offload, args.train.enable_gradient_checkpointing, args.train.activation_gpu_limit - ) - model.train() - logger.info( - f"rank{args.train.local_rank} Start training, train_steps: {args.train.train_steps}, epochs: {args.train.num_train_epochs}" - ) - for epoch in range(start_epoch, args.train.num_train_epochs): - if hasattr(train_dataloader, "set_epoch"): - train_dataloader.set_epoch(epoch) - - data_loader_tqdm = trange( - args.train.train_steps, - desc=f"Epoch {epoch + 1}/{args.train.num_train_epochs}", - total=args.train.train_steps, - initial=start_step, - disable=args.train.local_rank != 0, - ) - data_iterator = iter(train_dataloader) - for _ in range(start_step, args.train.train_steps): - global_step += 1 - - try: - micro_batches: List[Dict[str, Any]] = next(data_iterator) - except StopIteration: - logger.info(f"epoch:{epoch} Dataloader finished with drop_last {args.data.drop_last}") - break - - if global_step == 1: - helper.print_example(example=micro_batches[0], rank=args.train.local_rank) - - total_loss = 0 - synchronize() - start_time = time.time() - for micro_batch in micro_batches: - environ_meter.add(micro_batch) - if args.data.enable_multisource: - micro_batch.pop("ds_idx", None) - micro_batch.pop("source_name", None) - - if args.train.block_diffusion_mode: - noisy_input_ids = micro_batch["noisy_input_ids"] - clean_input_ids = micro_batch["input_ids"] - batch_size = noisy_input_ids.shape[0] - full_input_ids = torch.cat([noisy_input_ids, clean_input_ids], dim=1) - noisy_position_ids = torch.arange(noisy_input_ids.shape[1], device=get_device_type(), dtype=torch.long) - clean_position_ids = torch.arange(clean_input_ids.shape[1], device=get_device_type(), dtype=torch.long) - position_ids = torch.cat([noisy_position_ids, clean_position_ids], dim=0).unsqueeze(0).expand(batch_size, -1).clone() - micro_batch["input_ids"] = full_input_ids - micro_batch["position_ids"] = position_ids - micro_batch["attention_mask"] = block_diffusion_attn_mask_prototype.expand(batch_size, -1, -1, -1) - else: - micro_batch["attention_mask"] = None - - micro_batch = { - k: v.to(get_device_type(), non_blocking=True) if isinstance(v, torch.Tensor) else v - for k, v in micro_batch.items() - } - - labels = micro_batch.pop("labels", None) - - with model_fwd_context: - logits: "torch.Tensor" = model(**micro_batch, use_cache=False, output_router_logits=False).logits - if args.train.block_diffusion_mode: - noisy_logits = logits[:, :noisy_input_ids.shape[1]].contiguous() - else: - noisy_logits = logits - - if args.train.same_token_labels: - unscaled_loss = torch.nn.functional.cross_entropy( - noisy_logits.view(-1, noisy_logits.shape[-1]), - labels.view(-1), - reduction="none", - ) - loss = unscaled_loss.sum() / (labels != -100).sum() / len(micro_batches) - else: - shifted_noisy_logits = noisy_logits[:, :-1, :].contiguous() - shifted_labels = labels[:, 1:].contiguous() - unscaled_loss = torch.nn.functional.cross_entropy( - shifted_noisy_logits.view(-1, shifted_noisy_logits.shape[-1]), - shifted_labels.view(-1), - reduction="none", - ).view(shifted_noisy_logits.shape[0], -1) - loss = unscaled_loss.sum() / (shifted_labels != -100).sum() / len(micro_batches) - - with model_bwd_context: - loss.backward() - - total_loss += loss.item() - del micro_batch - - # Prefer model-provided clip_grad_norm_ (now both FSDP1 and FSDP2 registers custom grad norm clipping) - if hasattr(model, "clip_grad_norm_"): - _gn = model.clip_grad_norm_(args.train.max_grad_norm) - grad_norm = _gn.item() if hasattr(_gn, "item") else float(_gn) - else: - logger.info_rank0( - "Can NOT find regitsered clip_grad_norm_ method in the model, using PyTorch default implementation.." - ) - grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), args.train.max_grad_norm) - - optimizer.step() - lr_scheduler.step() - optimizer.zero_grad() - if hasattr(grad_norm, "full_tensor"): - grad_norm = grad_norm.full_tensor().item() - - # collect mean loss across data parallel group - total_loss, grad_norm = all_reduce((total_loss, grad_norm), group=get_parallel_state().fsdp_group) - synchronize() - delta_time = time.time() - start_time - lr = max(lr_scheduler.get_last_lr()) - train_metrics = environ_meter.step(delta_time, global_step=global_step) - - data_loader_tqdm.set_postfix_str(f"loss: {total_loss:.2f}, grad_norm: {grad_norm:.2f}, lr: {lr:.2e}") - data_loader_tqdm.update() - - if args.train.global_rank == 0: - if args.train.use_wandb: - train_metrics.update( - {"training/loss": total_loss, "training/grad_norm": grad_norm, "training/lr": lr} - ) - wandb.log(train_metrics, step=global_step) - - if args.train.profile_this_rank and global_step <= args.train.profile_end_step: - profiler.step() - if global_step == args.train.profile_end_step: - profiler.stop() - - if args.train.save_steps and global_step % args.train.save_steps == 0: - helper.empty_cache() - save_checkpoint_path = os.path.join(args.train.save_checkpoint_path, f"global_step_{global_step}") - state = { - "model": model, - "optimizer": optimizer, - "extra_state": { - "global_step": global_step, - "lr_scheduler": lr_scheduler.state_dict(), - "train_dataloader": train_dataloader.state_dict(), - "environ_meter": environ_meter.state_dict(), - "torch_rng_state": torch.get_rng_state(), - }, - } - Checkpointer.save(args.train.save_checkpoint_path, state, global_steps=global_step) - - dist.barrier() - logger.info_rank0(f"Distributed checkpoint saved at {save_checkpoint_path} successfully!") - - data_loader_tqdm.close() - start_step = 0 - helper.print_device_mem_info(f"VRAM usage after epoch {epoch + 1}") - if args.train.save_epochs and (epoch + 1) % args.train.save_epochs == 0: - helper.empty_cache() - save_checkpoint_path = os.path.join(args.train.save_checkpoint_path, f"global_step_{global_step}") - state = { - "model": model, - "optimizer": optimizer, - "extra_state": { - "global_step": global_step, - "lr_scheduler": lr_scheduler.state_dict(), - "train_dataloader": train_dataloader.state_dict(), - "environ_meter": environ_meter.state_dict(), - "torch_rng_state": torch.get_rng_state(), - }, - } - Checkpointer.save(args.train.save_checkpoint_path, state, global_steps=global_step) - dist.barrier() - logger.info_rank0(f"Distributed checkpoint saved at {save_checkpoint_path} successfully!") - - synchronize() - # release memory - del optimizer, lr_scheduler - helper.empty_cache() - # save model in huggingface's format - if args.train.global_rank == 0 and args.train.save_hf_weights and save_checkpoint_path is not None: - hf_weights_path = os.path.join(save_checkpoint_path, "hf_ckpt") - model_state_dict = ckpt_to_state_dict( - save_checkpoint_path=save_checkpoint_path, - output_dir=args.train.output_dir, - ckpt_manager=args.train.ckpt_manager, - ) - save_model_weights(hf_weights_path, model_state_dict, model_assets=model_assets) - logger.info_rank0(f"Huggingface checkpoint saved at {hf_weights_path} successfully!") - - dist.barrier() - dist.destroy_process_group() + run_llada2_training(LLaDA2Arguments) if __name__ == "__main__": diff --git a/tasks/train_llada2_bd_with_dparallel.py b/tasks/train_llada2_bd_with_dparallel.py index 2084cc5..f00a321 100644 --- a/tasks/train_llada2_bd_with_dparallel.py +++ b/tasks/train_llada2_bd_with_dparallel.py @@ -1,623 +1,9 @@ -import json -import os -import time -from dataclasses import asdict, dataclass, field -from functools import partial -from typing import Any, Dict, List, Literal, Tuple, Optional +from train_llada2_common import LLaDA2Arguments, run_llada2_training -import torch -import torch.nn.functional as F -import torch.distributed as dist -import wandb -from tqdm import trange -from veomni.checkpoint import build_checkpointer, ckpt_to_state_dict -from veomni.data import ( - build_dataloader, - build_iterative_dataset, - build_mapping_dataset, -) -from veomni.distributed.offloading import build_activation_offloading_context -from veomni.distributed.parallel_state import get_parallel_state, init_parallel_state -from veomni.distributed.torch_parallelize import build_parallelize_model -from veomni.models import build_foundation_model, build_tokenizer, save_model_assets, save_model_weights -from veomni.optim import build_lr_scheduler, build_optimizer -from veomni.utils import helper -from veomni.utils.arguments import DataArguments, ModelArguments, TrainingArguments, parse_args, save_args -from veomni.utils.device import ( - get_device_type, - get_nccl_backend, - get_torch_device, - synchronize, -) -from veomni.utils.dist_utils import all_reduce -from veomni.models.registry import ModelRegistry -ModelRegistry.register_modeling_path("models.llada2_moe") -from dataset.data_transform import process_mdm_tokenized_example, process_mdm_sft_example -from dataset import build_local_dataset - - -logger = helper.create_logger(__name__) - -@dataclass -class LLaDA2ModelArguments(ModelArguments): - attn_implementation: Optional[Literal["eager", "sdpa", "flex_attention"]] = field( - default="sdpa", - metadata={"help": "Attention implementation to use."}, - ) - - -@dataclass -class LLaDA2DataArguments(DataArguments): - data_type: Literal["conversation", "tokenid"] = field( - default="conversation", - metadata={"help": "Type of the training data."}, - ) - datasets_type: Literal["mapping", "local"] = field( - default="mapping", - metadata={"help": "Type of the datasets."}, - ) - text_keys: str = field( - default="messages", - metadata={"help": "Key to get text from the training data."}, - ) - noise_range_low: float = field( - default=0.3, - metadata={"help": "Noise level for random flip input_ids to mask_ids"} - ) - noise_range_high: float = field( - default=0.8, - metadata={"help": "Noise level for random flip input_ids to mask_ids"} - ) - - def __post_init__(self): - super().__post_init__() - if self.noise_range_low > self.noise_range_high: - raise ValueError( - f"noise_range_low ({self.noise_range_low}) " - f"cannot be greater than noise_range_high ({self.noise_range_high})." - ) - - if not (0.0 <= self.noise_range_low <= 1.0): - raise ValueError( - f"noise_range_low must be between 0.0 and 1.0, but got {self.noise_range_low}." - ) - - if not (0.0 <= self.noise_range_high <= 1.0): - raise ValueError( - f"noise_range_high must be between 0.0 and 1.0, but got {self.noise_range_high}." - ) - - -@dataclass -class LLaDA2TrainingArguments(TrainingArguments): - beta1: float = field( - default=0.9, - metadata={"help": "AdamW optimizer beta1."}, - ) - beta2: float = field( - default=0.999, - metadata={"help": "AdamW optimizer beta2"}, - ) - confidence_beta: float = field( - default=0.0, - metadata={"help": "Weight for the confidence loss entropy of correct predictions. Set to 0 to disable."}, - ) - block_diffusion_mode: bool = field( - default=False, - metadata={"help": "If train MDM in block_diffusion mode. True: use block_diffusion, False: full_attention"} - ) - block_size: int = field( - default=32, - metadata={"help": "The block size for block diffusion block size"} - ) - same_token_labels: bool = field( - default=False, - metadata={"help": "If use same token location labels. True: no shift, False: use next-token prediction shift."} - ) - - -@dataclass -class Arguments: - model: "LLaDA2ModelArguments" = field(default_factory=LLaDA2ModelArguments) - data: "LLaDA2DataArguments" = field(default_factory=LLaDA2DataArguments) - train: "LLaDA2TrainingArguments" = field(default_factory=LLaDA2TrainingArguments) - - -def block_diffusion_mask(b, h, q_idx, kv_idx, block_size=None, n=None): - """ - Constructs the specialized block diffusion attention mask for training - composed of three masks: - - **Block Diagonal Mask (M_BD)**: Self-attention within noised blocks - - **Offset Block Causal Mask (M_OBC)**: Cross-attention for conditional context - - **Block Causal Mask (M_BC)**: Attention to update x0 - - Args: - b, h: Batch and head indices (ignored for mask logic). - q_idx, kv_idx: Query and Key indices. - seq_len: Total sequence length. - block_size: Defines the block structure. - - Returns: - A boolean attention mask. - """ - - # Indicate whether token belongs to xt or x0 - x0_flag_q = (q_idx >= n) - x0_flag_kv = (kv_idx >= n) - - # Compute block indices - block_q = torch.where(x0_flag_q == 1, - (q_idx - n) // block_size, - q_idx // block_size) - block_kv = torch.where(x0_flag_kv == 1, - (kv_idx - n) // block_size, - kv_idx // block_size) - - # **1. Block Diagonal Mask (M_BD) ** - block_diagonal = (block_q == block_kv) & (x0_flag_q == x0_flag_kv) - - # **2. Offset Block-Causal Mask (M_OBC) ** - offset_block_causal = ( - (block_q > block_kv) - & (x0_flag_kv == 1) - & (x0_flag_q == 0) - ) - - # **3. Block-Causal Mask (M_BC) ** - block_causal = (block_q >= block_kv) & (x0_flag_kv == 1) & (x0_flag_q == 1) - - # **4. Combine Masks ** - return block_diagonal | offset_block_causal | block_causal - -def compute_confidence_loss(logits: torch.Tensor, labels: torch.Tensor) -> torch.Tensor: - """ - Calculate the average entropy of the output distribution at positions where the model predicts correctly. - Args: - logits (torch.Tensor): The raw output logits from the model, with shape (batch_size, seq_len, vocab_size). - labels (torch.Tensor): The ground truth labels, with shape (batch_size, seq_len). -100 indicates positions to be ignored. - Returns: - torch.Tensor: A scalar tensor representing the confidence loss. Returns 0 if there are no correct predictions. - """ - labels = labels.to(logits.device) - - valid_mask = (labels != -100) - if not valid_mask.any(): - return torch.tensor(0.0, device=logits.device) - - predicted_tokens = torch.argmax(logits, dim=-1) - - correct_mask = (predicted_tokens == labels) & valid_mask - - if correct_mask.sum() == 0: - return torch.tensor(0.0, device=logits.device) - - log_probs = F.log_softmax(logits, dim=-1) - probs = torch.exp(log_probs) - entropy_per_token = -torch.sum(probs * log_probs, dim=-1) - - entropy_at_correct_positions = entropy_per_token[correct_mask] - - confidence_loss = entropy_at_correct_positions.mean() - - return confidence_loss - def main(): - dist.init_process_group(backend=get_nccl_backend()) - args = parse_args(Arguments) - logger.info(f"Process rank: {args.train.global_rank}, world size: {args.train.world_size}") - logger.info_rank0(json.dumps(asdict(args), indent=2)) - get_torch_device().set_device(f"{get_device_type()}:{args.train.local_rank}") - helper.set_seed(args.train.seed, args.train.enable_full_determinism) - if args.train.local_rank == 0: - helper.enable_third_party_logging() - - if args.train.global_rank == 0: - save_args(args, args.train.output_dir) - - Checkpointer = build_checkpointer(dist_backend=args.train.data_parallel_mode, ckpt_manager=args.train.ckpt_manager) - - init_parallel_state( - dp_size=args.train.data_parallel_size, - dp_replicate_size=args.train.data_parallel_replicate_size, - dp_shard_size=args.train.data_parallel_shard_size, - tp_size=args.train.tensor_parallel_size, - ep_size=args.train.expert_parallel_size, - pp_size=args.train.pipeline_parallel_size, - cp_size=args.train.context_parallel_size, - ulysses_size=args.train.ulysses_parallel_size, - dp_mode=args.train.data_parallel_mode, - ) - - logger.info_rank0("Prepare data") - tokenizer = build_tokenizer(args.model.tokenizer_path) - if args.data.data_type == "conversation": - if not tokenizer.chat_template: - raise ValueError(f"No chat template found in the tokenizer.") - - transform = partial( - process_mdm_sft_example, - tokenizer=tokenizer, - max_seq_len=args.data.max_seq_len, - text_keys=args.data.text_keys, - noise_range=(args.data.noise_range_low, args.data.noise_range_high), - mask_token_id=156895, - ) - elif args.data.data_type == "tokenid": - transform = partial( - process_mdm_tokenized_example, - max_seq_len=args.data.max_seq_len, - text_keys=args.data.text_keys, - noise_range=(args.data.noise_range_low, args.data.noise_range_high), - mask_token_id=156895, - ) - else: - raise NotImplementedError(f"Unsupported data type: {args.data.data_type}.") - - if args.data.dataloader_type == "native": - if args.data.datasets_type == "iterable": - logger.info_rank0("Start building iterative dataset") - train_dataset = build_iterative_dataset(args.data.train_path, transform=transform, seed=args.train.seed) - elif args.data.datasets_type == "mapping": - logger.info_rank0("Start building mapping dataset") - train_dataset = build_mapping_dataset(args.data.train_path, transform=transform) - elif args.data.datasets_type == "local": - logger.info_rank0("Start building local dataset") - train_dataset = build_local_dataset(args.data.train_path, transform=transform) - - dataset_length = None if not hasattr(train_dataset, "__len__") else len(train_dataset) - if args.data.datasets_type == "mapping" or args.data.datasets_type == "local": - dataset_length = dataset_length / args.train.data_parallel_size - args.train.compute_train_steps(args.data.max_seq_len, args.data.train_size, dataset_length) - - train_dataloader = build_dataloader( - dataset=train_dataset, - micro_batch_size=args.train.micro_batch_size, - global_batch_size=args.train.global_batch_size, - dataloader_batch_size=args.train.dataloader_batch_size, - seed=args.train.seed, - max_seq_len=args.data.max_seq_len, - train_steps=args.train.train_steps, - rmpad=args.train.rmpad, - rmpad_with_pos_ids=args.train.rmpad_with_pos_ids, - bsz_warmup_ratio=args.train.bsz_warmup_ratio, - bsz_warmup_init_mbtoken=args.train.bsz_warmup_init_mbtoken, - dyn_bsz_margin=args.train.dyn_bsz_margin, - dyn_bsz_buffer_size=args.train.dyn_bsz_buffer_size, - num_workers=args.data.num_workers, - drop_last=args.data.drop_last, - pin_memory=args.data.pin_memory, - prefetch_factor=args.data.prefetch_factor, - ) - else: - raise NotImplementedError(f"Unsupported dataloader type: {args.data.dataloader_type}.") - - logger.info_rank0("Prepare model") - model = build_foundation_model( - config_path=args.model.config_path, - weights_path=args.model.model_path, - torch_dtype="float32" if args.train.enable_mixed_precision else "bfloat16", - attn_implementation=args.model.attn_implementation, - moe_implementation=args.model.moe_implementation, - init_device=args.train.init_device, - force_use_huggingface=args.model.force_use_huggingface, - ) - model_config = model.config - helper.print_device_mem_info("VRAM usage after building model") - - get_optimizer_pre_hook = getattr(model, "get_optimizer_pre_hook", None) - model = build_parallelize_model( - model, - init_device=args.train.init_device, - weights_path=args.model.model_path, - enable_full_shard=args.train.enable_full_shard, - enable_mixed_precision=args.train.enable_mixed_precision, - enable_gradient_checkpointing=args.train.enable_gradient_checkpointing, - enable_fsdp_offload=args.train.enable_fsdp_offload, - basic_modules=model._no_split_modules + args.model.basic_modules, - enable_reentrant=args.train.enable_reentrant, - enable_forward_prefetch=args.train.enable_forward_prefetch, - broadcast_model_weights_from_rank0=args.train.broadcast_model_weights_from_rank0 - ) - - optimizer = build_optimizer( - model, - lr=args.train.lr, - betas=(args.train.beta1, args.train.beta2), - weight_decay=args.train.weight_decay, - fused=True, - optimizer_type=args.train.optimizer, - ) - - if get_optimizer_pre_hook is not None: - optimizer_pre_hook = get_optimizer_pre_hook(model, model_config, args.train.data_parallel_mode) - optimizer.register_step_pre_hook(optimizer_pre_hook) - - lr_scheduler = build_lr_scheduler( - optimizer, - train_steps=args.train.train_steps * args.train.num_train_epochs, - lr=args.train.lr, - lr_min=args.train.lr_min, - lr_decay_style=args.train.lr_decay_style, - lr_decay_ratio=args.train.lr_decay_ratio, - lr_warmup_ratio=args.train.lr_warmup_ratio, - lr_start=args.train.lr_start, - ) - - if args.train.global_rank == 0: - if args.train.use_wandb: - wandb.init( - project=args.train.wandb_project, - name=args.train.wandb_name, - config={**vars(args.model), **vars(args.data), **vars(args.train)}, # flatten dict - ) - - # save model_assets before training - model_assets = [model_config, tokenizer] - save_model_assets(args.train.model_assets_dir, model_assets) - - if args.train.profile_this_rank: - profiler = helper.create_profiler( - start_step=args.train.profile_start_step, - end_step=args.train.profile_end_step, - trace_dir=args.train.profile_trace_dir, - record_shapes=args.train.profile_record_shapes, - profile_memory=args.train.profile_profile_memory, - with_stack=args.train.profile_with_stack, - global_rank=args.train.global_rank, - ) - profiler.start() - - start_epoch, start_step, global_step = 0, 0, 0 - save_checkpoint_path = None - environ_meter = helper.EnvironMeter( - config=model_config, - global_batch_size=args.train.global_batch_size, - rmpad=args.train.rmpad, - rmpad_with_pos_ids=args.train.rmpad_with_pos_ids, - empty_cache_steps=args.train.empty_cache_steps, - enable_multisource=args.data.enable_multisource, - dataloader=train_dataloader, - data_path=args.data.train_path, - ) - - if args.train.load_checkpoint_path: - state = {"model": model, "optimizer": optimizer, "extra_state": {}} # cannot be None - Checkpointer.load(args.train.load_checkpoint_path, state) - global_step = state["extra_state"]["global_step"] - start_epoch = global_step // args.train.train_steps - start_step = global_step % args.train.train_steps - lr_scheduler.load_state_dict(state["extra_state"]["lr_scheduler"]) - train_dataloader.load_state_dict(state["extra_state"]["train_dataloader"]) - environ_meter.load_state_dict(state["extra_state"]["environ_meter"]) - torch.set_rng_state(state["extra_state"]["torch_rng_state"]) - if start_step == 0: # resume at the end of epoch - iter(train_dataloader) # clear resume state and prefetch data - - dist.barrier() - logger.info_rank0(f"Load distributed checkpoint from {args.train.load_checkpoint_path} successfully!") - - # Build block diffusion attention mask - if args.train.block_diffusion_mode: - bd_attn_full_len = args.data.max_seq_len * 2 - block_size = args.train.block_size - # NOTE: Boolean dtype block diffusion attention mask - block_diffusion_attn_mask_flag = block_diffusion_mask( - b=None, h=None, - q_idx=torch.arange(bd_attn_full_len)[:, None], - kv_idx=torch.arange(bd_attn_full_len)[None, :], - block_size=block_size, - n=args.data.max_seq_len - ).unsqueeze(0).unsqueeze(0) - - block_diffusion_attn_mask_prototype = torch.zeros_like( - block_diffusion_attn_mask_flag, - dtype=torch.float32 if args.train.enable_mixed_precision else torch.bfloat16 - ) - block_diffusion_attn_mask_prototype.masked_fill_(block_diffusion_attn_mask_flag.logical_not(), float("-inf")) - - helper.empty_cache() - model_fwd_context, model_bwd_context = build_activation_offloading_context( - args.train.enable_activation_offload, args.train.enable_gradient_checkpointing, args.train.activation_gpu_limit - ) - model.train() - logger.info( - f"rank{args.train.local_rank} Start training, train_steps: {args.train.train_steps}, epochs: {args.train.num_train_epochs}" - ) - for epoch in range(start_epoch, args.train.num_train_epochs): - if hasattr(train_dataloader, "set_epoch"): - train_dataloader.set_epoch(epoch) - - data_loader_tqdm = trange( - args.train.train_steps, - desc=f"Epoch {epoch + 1}/{args.train.num_train_epochs}", - total=args.train.train_steps, - initial=start_step, - disable=args.train.local_rank != 0, - ) - data_iterator = iter(train_dataloader) - for _ in range(start_step, args.train.train_steps): - global_step += 1 - - try: - micro_batches: List[Dict[str, Any]] = next(data_iterator) - except StopIteration: - logger.info(f"epoch:{epoch} Dataloader finished with drop_last {args.data.drop_last}") - break - - if global_step == 1: - helper.print_example(example=micro_batches[0], rank=args.train.local_rank) - - total_loss = 0 - synchronize() - start_time = time.time() - num_accumulation_steps = len(micro_batches) - total_consistency_loss = 0 - total_confidence_loss = 0 - - for micro_batch in micro_batches: - environ_meter.add(micro_batch) - if args.data.enable_multisource: - micro_batch.pop("ds_idx", None) - micro_batch.pop("source_name", None) - - micro_batch = { - k: v.to(get_device_type(), non_blocking=True) if isinstance(v, torch.Tensor) else v - for k, v in micro_batch.items() - } - if args.train.block_diffusion_mode: - noisy_input_ids = micro_batch["noisy_input_ids"] - clean_input_ids = micro_batch["input_ids"] - batch_size = noisy_input_ids.shape[0] - full_input_ids = torch.cat([noisy_input_ids, clean_input_ids], dim=1) - noisy_position_ids = torch.arange(noisy_input_ids.shape[1], device=get_device_type(), dtype=torch.long) - clean_position_ids = torch.arange(clean_input_ids.shape[1], device=get_device_type(), dtype=torch.long) - position_ids = torch.cat([noisy_position_ids, clean_position_ids], dim=0).unsqueeze(0).expand(batch_size, -1).clone() - micro_batch["input_ids"] = full_input_ids - micro_batch["position_ids"] = position_ids - micro_batch["attention_mask"] = block_diffusion_attn_mask_prototype.expand(batch_size, -1, -1, -1) - else: - micro_batch["attention_mask"] = None - - labels = micro_batch.pop("labels", None) - - with model_fwd_context: - logits: "torch.Tensor" = model(**micro_batch, use_cache=False, output_router_logits=False).logits - if args.train.block_diffusion_mode: - noisy_logits = logits[:, :noisy_input_ids.shape[1]].contiguous() - else: - noisy_logits = logits - - confidence_loss = torch.tensor(0.0, device=noisy_logits.device) - if args.train.confidence_beta > 0: - confidence_loss = compute_confidence_loss( - logits=noisy_logits, - labels=labels, - ) - - if args.train.same_token_labels: - unscaled_loss = torch.nn.functional.cross_entropy( - noisy_logits.view(-1, noisy_logits.shape[-1]), - labels.view(-1), - reduction="none", - ).view(noisy_logits.shape[0], -1) - consistency_loss = unscaled_loss.sum() / (labels != -100).sum() - else: - shifted_noisy_logits = noisy_logits[:, :-1, :].contiguous() - shifted_labels = labels[:, 1:].contiguous() - unscaled_loss = torch.nn.functional.cross_entropy( - shifted_noisy_logits.view(-1, shifted_noisy_logits.shape[-1]), - shifted_labels.view(-1), - reduction="none", - ).view(shifted_noisy_logits.shape[0], -1) - consistency_loss = unscaled_loss.sum() / (shifted_labels != -100).sum() - - combined_loss = consistency_loss + confidence_loss * args.train.confidence_beta - loss = combined_loss / num_accumulation_steps - with model_bwd_context: - loss.backward() - - total_loss += loss.item() - total_consistency_loss += consistency_loss.item() / num_accumulation_steps - total_confidence_loss += confidence_loss.item() / num_accumulation_steps - del micro_batch - - # Prefer model-provided clip_grad_norm_ (now both FSDP1 and FSDP2 registers custom grad norm clipping) - if hasattr(model, "clip_grad_norm_"): - _gn = model.clip_grad_norm_(args.train.max_grad_norm) - grad_norm = _gn.item() if hasattr(_gn, "item") else float(_gn) - else: - logger.info_rank0( - "Can NOT find regitsered clip_grad_norm_ method in the model, using PyTorch default implementation.." - ) - grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), args.train.max_grad_norm) - - optimizer.step() - lr_scheduler.step() - optimizer.zero_grad() - if hasattr(grad_norm, "full_tensor"): - grad_norm = grad_norm.full_tensor().item() - - # collect mean loss across data parallel group - total_loss, grad_norm = all_reduce((total_loss, grad_norm), group=get_parallel_state().fsdp_group) - synchronize() - delta_time = time.time() - start_time - lr = max(lr_scheduler.get_last_lr()) - train_metrics = environ_meter.step(delta_time, global_step=global_step) - - data_loader_tqdm.set_postfix_str(f"loss: {total_loss:.2f}, cons: {total_consistency_loss:.2f}, conf: {total_confidence_loss:.2f}, grad_norm: {grad_norm:.2f}, lr: {lr:.2e}") - data_loader_tqdm.update() - - if args.train.global_rank == 0: - if args.train.use_wandb: - train_metrics.update( - {"training/loss": total_loss, "training/cons_loss": total_consistency_loss, "training/conf_loss": total_confidence_loss, "training/grad_norm": grad_norm, "training/lr": lr} - ) - wandb.log(train_metrics, step=global_step) - - if args.train.profile_this_rank and global_step <= args.train.profile_end_step: - profiler.step() - if global_step == args.train.profile_end_step: - profiler.stop() - - if args.train.save_steps and global_step % args.train.save_steps == 0: - helper.empty_cache() - save_checkpoint_path = os.path.join(args.train.save_checkpoint_path, f"global_step_{global_step}") - state = { - "model": model, - "optimizer": optimizer, - "extra_state": { - "global_step": global_step, - "lr_scheduler": lr_scheduler.state_dict(), - "train_dataloader": train_dataloader.state_dict(), - "environ_meter": environ_meter.state_dict(), - "torch_rng_state": torch.get_rng_state(), - }, - } - Checkpointer.save(args.train.save_checkpoint_path, state, global_steps=global_step) - - dist.barrier() - logger.info_rank0(f"Distributed checkpoint saved at {save_checkpoint_path} successfully!") - - data_loader_tqdm.close() - start_step = 0 - helper.print_device_mem_info(f"VRAM usage after epoch {epoch + 1}") - if args.train.save_epochs and (epoch + 1) % args.train.save_epochs == 0: - helper.empty_cache() - save_checkpoint_path = os.path.join(args.train.save_checkpoint_path, f"global_step_{global_step}") - state = { - "model": model, - "optimizer": optimizer, - "extra_state": { - "global_step": global_step, - "lr_scheduler": lr_scheduler.state_dict(), - "train_dataloader": train_dataloader.state_dict(), - "environ_meter": environ_meter.state_dict(), - "torch_rng_state": torch.get_rng_state(), - }, - } - Checkpointer.save(args.train.save_checkpoint_path, state, global_steps=global_step) - dist.barrier() - logger.info_rank0(f"Distributed checkpoint saved at {save_checkpoint_path} successfully!") - - synchronize() - # release memory - del optimizer, lr_scheduler - helper.empty_cache() - # save model in huggingface's format - if args.train.global_rank == 0 and args.train.save_hf_weights and save_checkpoint_path is not None: - hf_weights_path = os.path.join(save_checkpoint_path, "hf_ckpt") - model_state_dict = ckpt_to_state_dict( - save_checkpoint_path=save_checkpoint_path, - output_dir=args.train.output_dir, - ckpt_manager=args.train.ckpt_manager, - ) - save_model_weights(hf_weights_path, model_state_dict, model_assets=model_assets) - logger.info_rank0(f"Huggingface checkpoint saved at {hf_weights_path} successfully!") - - dist.barrier() - dist.destroy_process_group() + run_llada2_training(LLaDA2Arguments) if __name__ == "__main__": - main() \ No newline at end of file + main() diff --git a/tasks/train_llada2_common.py b/tasks/train_llada2_common.py new file mode 100644 index 0000000..a9618af --- /dev/null +++ b/tasks/train_llada2_common.py @@ -0,0 +1,663 @@ +import json +import os +import time +from dataclasses import asdict, dataclass, field +from datetime import timedelta +from functools import partial +from typing import Any, Dict, List, Literal, Optional, Tuple + +import torch +import torch.distributed as dist +import torch.nn.functional as F +import wandb +from torch.utils.checkpoint import set_checkpoint_debug_enabled +from tqdm import trange + +import models.llada2_moe # noqa: F401 - registers LLaDA2 MoE with the VeOmni loader. +from veomni.arguments import DataArguments, ModelArguments, TrainingArguments, VeOmniArguments, parse_args, save_args +from veomni.checkpoint import build_checkpointer +from veomni.data import build_dataloader, build_dataset +from veomni.distributed.clip_grad_norm import veomni_clip_grad_norm +from veomni.distributed.offloading import build_activation_offloading_context +from veomni.distributed.parallel_state import get_parallel_state, init_parallel_state +from veomni.distributed.torch_parallelize import build_parallelize_model +from veomni.models import build_foundation_model, build_tokenizer, save_model_assets +from veomni.optim import build_lr_scheduler, build_optimizer +from veomni.utils import helper +from veomni.utils.device import ( + get_device_type, + get_dist_comm_backend, + get_torch_device, + is_nccl_backend, + synchronize, +) +from veomni.utils.dist_utils import all_reduce +from veomni.utils.save_safetensor_utils import save_hf_safetensor + +try: + from dataset import build_local_dataset + from dataset.data_transform import process_mdm_sft_example, process_mdm_tokenized_example +except ImportError: + from tasks.dataset import build_local_dataset + from tasks.dataset.data_transform import process_mdm_sft_example, process_mdm_tokenized_example + + +logger = helper.create_logger(__name__) + + +@dataclass +class LLaDA2ModelArguments(ModelArguments): + attn_implementation: Optional[Literal["eager", "sdpa", "flex_attention"]] = field( + default=None, + metadata={"help": "Deprecated. Use model.ops_implementation.attn_implementation."}, + ) + moe_implementation: Optional[str] = field( + default=None, + metadata={"help": "Deprecated. Use model.ops_implementation.moe_implementation."}, + ) + + def __post_init__(self): + super().__post_init__() + if self.attn_implementation is not None: + self.ops_implementation.attn_implementation = self.attn_implementation + if self.moe_implementation is not None: + self.ops_implementation.moe_implementation = self.moe_implementation + + +@dataclass +class LLaDA2DataArguments(DataArguments): + data_type: Literal["conversation", "tokenid"] = field( + default="conversation", + metadata={"help": "Type of the training data."}, + ) + datasets_type: Literal["mapping", "iterable", "local"] = field( + default="mapping", + metadata={"help": "Type of the datasets."}, + ) + text_keys: Optional[str] = field( + default=None, + metadata={"help": "Key to get text or token ids from the training data."}, + ) + noise_range_low: float = field( + default=0.3, + metadata={"help": "Lower bound of random mask noise ratio."}, + ) + noise_range_high: float = field( + default=0.8, + metadata={"help": "Upper bound of random mask noise ratio."}, + ) + mask_token_id: int = field( + default=156895, + metadata={"help": "LLaDA2 mask token id."}, + ) + + def __post_init__(self): + if self.text_keys is None: + self.text_keys = "input_ids" if self.data_type == "tokenid" else "messages" + super().__post_init__() + if self.noise_range_low > self.noise_range_high: + raise ValueError( + f"noise_range_low ({self.noise_range_low}) cannot be greater than " + f"noise_range_high ({self.noise_range_high})." + ) + if not (0.0 <= self.noise_range_low <= 1.0): + raise ValueError(f"noise_range_low must be between 0.0 and 1.0, but got {self.noise_range_low}.") + if not (0.0 <= self.noise_range_high <= 1.0): + raise ValueError(f"noise_range_high must be between 0.0 and 1.0, but got {self.noise_range_high}.") + + +@dataclass +class LLaDA2TrainingArguments(TrainingArguments): + beta1: float = field( + default=0.9, + metadata={"help": "AdamW optimizer beta1."}, + ) + beta2: float = field( + default=0.999, + metadata={"help": "AdamW optimizer beta2."}, + ) + confidence_beta: float = field( + default=0.0, + metadata={"help": "Weight for the confidence loss entropy of correct predictions. Set to 0 to disable."}, + ) + block_diffusion_mode: bool = field( + default=False, + metadata={"help": "Train MDM in block diffusion mode."}, + ) + block_size: int = field( + default=32, + metadata={"help": "Block size for block diffusion."}, + ) + same_token_labels: bool = field( + default=False, + metadata={"help": "Use same token labels instead of next-token shifted labels."}, + ) + + +@dataclass +class LLaDA2Arguments(VeOmniArguments): + model: LLaDA2ModelArguments = field(default_factory=LLaDA2ModelArguments) + data: LLaDA2DataArguments = field(default_factory=LLaDA2DataArguments) + train: LLaDA2TrainingArguments = field(default_factory=LLaDA2TrainingArguments) + + +def block_diffusion_mask(b, h, q_idx, kv_idx, block_size=None, n=None): + del b, h + x0_flag_q = q_idx >= n + x0_flag_kv = kv_idx >= n + + block_q = torch.where(x0_flag_q == 1, (q_idx - n) // block_size, q_idx // block_size) + block_kv = torch.where(x0_flag_kv == 1, (kv_idx - n) // block_size, kv_idx // block_size) + + block_diagonal = (block_q == block_kv) & (x0_flag_q == x0_flag_kv) + offset_block_causal = (block_q > block_kv) & (x0_flag_kv == 1) & (x0_flag_q == 0) + block_causal = (block_q >= block_kv) & (x0_flag_kv == 1) & (x0_flag_q == 1) + return block_diagonal | offset_block_causal | block_causal + + +def compute_confidence_loss(logits: torch.Tensor, labels: torch.Tensor) -> torch.Tensor: + labels = labels.to(logits.device) + valid_mask = labels != -100 + if not valid_mask.any(): + return torch.tensor(0.0, device=logits.device) + + predicted_tokens = torch.argmax(logits, dim=-1) + correct_mask = (predicted_tokens == labels) & valid_mask + if correct_mask.sum() == 0: + return torch.tensor(0.0, device=logits.device) + + log_probs = F.log_softmax(logits, dim=-1) + probs = torch.exp(log_probs) + entropy_per_token = -torch.sum(probs * log_probs, dim=-1) + return entropy_per_token[correct_mask].mean() + + +def _build_transform(args: LLaDA2Arguments, tokenizer): + noise_range = (args.data.noise_range_low, args.data.noise_range_high) + if args.data.data_type == "conversation": + if not tokenizer.chat_template: + raise ValueError("No chat template found in the tokenizer.") + return partial( + process_mdm_sft_example, + tokenizer=tokenizer, + max_seq_len=args.data.max_seq_len, + text_keys=args.data.text_keys, + noise_range=noise_range, + mask_token_id=args.data.mask_token_id, + ) + if args.data.data_type == "tokenid": + return partial( + process_mdm_tokenized_example, + max_seq_len=args.data.max_seq_len, + text_keys=args.data.text_keys, + noise_range=noise_range, + mask_token_id=args.data.mask_token_id, + ) + raise NotImplementedError(f"Unsupported data type: {args.data.data_type}.") + + +def _build_train_dataset(args: LLaDA2Arguments, transform): + if args.data.datasets_type == "local": + return build_local_dataset(args.data.train_path, transform=transform, seed=args.train.seed) + + return build_dataset( + dataset_name=args.data.dataset_name, + transform=transform, + dataloader_batch_size=args.train.dataloader_batch_size, + seed=args.train.seed, + **asdict(args.data), + ) + + +def _build_block_diffusion_mask(args: LLaDA2Arguments) -> Optional[torch.Tensor]: + if not args.train.block_diffusion_mode: + return None + + full_len = args.data.max_seq_len * 2 + mask_flag = block_diffusion_mask( + b=None, + h=None, + q_idx=torch.arange(full_len)[:, None], + kv_idx=torch.arange(full_len)[None, :], + block_size=args.train.block_size, + n=args.data.max_seq_len, + ).unsqueeze(0).unsqueeze(0) + + mask_dtype = torch.float32 if args.train.accelerator.fsdp_config.mixed_precision.enable else torch.bfloat16 + mask = torch.zeros_like(mask_flag, dtype=mask_dtype) + mask.masked_fill_(mask_flag.logical_not(), float("-inf")) + return mask + + +def _prepare_micro_batch( + args: LLaDA2Arguments, + micro_batch: Dict[str, Any], + block_diffusion_attn_mask: Optional[torch.Tensor], +) -> Tuple[Dict[str, Any], int]: + if args.train.block_diffusion_mode: + noisy_input_ids = micro_batch.pop("noisy_input_ids") + clean_input_ids = micro_batch["input_ids"] + batch_size = noisy_input_ids.shape[0] + noisy_seq_len = noisy_input_ids.shape[1] + + full_input_ids = torch.cat([noisy_input_ids, clean_input_ids], dim=1) + noisy_position_ids = torch.arange(noisy_seq_len, device=full_input_ids.device, dtype=torch.long) + clean_position_ids = torch.arange(clean_input_ids.shape[1], device=full_input_ids.device, dtype=torch.long) + position_ids = torch.cat([noisy_position_ids, clean_position_ids], dim=0).unsqueeze(0) + + micro_batch["input_ids"] = full_input_ids + micro_batch["position_ids"] = position_ids.expand(batch_size, -1).clone() + micro_batch["attention_mask"] = block_diffusion_attn_mask.expand(batch_size, -1, -1, -1) + else: + noisy_seq_len = micro_batch["input_ids"].shape[1] + micro_batch.pop("noisy_input_ids", None) + micro_batch["attention_mask"] = None + + micro_batch = { + k: v.to(get_device_type(), non_blocking=True) if isinstance(v, torch.Tensor) else v + for k, v in micro_batch.items() + } + return micro_batch, noisy_seq_len + + +def _compute_llada2_loss( + args: LLaDA2Arguments, + noisy_logits: torch.Tensor, + labels: torch.Tensor, + num_micro_steps: int, +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + confidence_loss = torch.tensor(0.0, device=noisy_logits.device) + if args.train.confidence_beta > 0: + confidence_loss = compute_confidence_loss(logits=noisy_logits, labels=labels) + + if args.train.same_token_labels: + unscaled_loss = F.cross_entropy( + noisy_logits.view(-1, noisy_logits.shape[-1]), + labels.view(-1), + reduction="none", + ) + denom = (labels != -100).sum().clamp_min(1) + consistency_loss = unscaled_loss.sum() / denom + else: + shifted_noisy_logits = noisy_logits[:, :-1, :].contiguous() + shifted_labels = labels[:, 1:].contiguous() + unscaled_loss = F.cross_entropy( + shifted_noisy_logits.view(-1, shifted_noisy_logits.shape[-1]), + shifted_labels.view(-1), + reduction="none", + ) + denom = (shifted_labels != -100).sum().clamp_min(1) + consistency_loss = unscaled_loss.sum() / denom + + combined_loss = consistency_loss + confidence_loss * args.train.confidence_beta + return combined_loss / num_micro_steps, consistency_loss, confidence_loss + + +def run_llada2_training(arguments_cls=LLaDA2Arguments): + nccl_timeout = os.getenv("NCCL_TIMEOUT", None) + pg_nccl_timeout = None + if nccl_timeout is not None and is_nccl_backend(): + pg_nccl_timeout = timedelta(seconds=int(nccl_timeout)) + logger.info(f"Process_group timeout: {nccl_timeout}") + dist.init_process_group(backend=get_dist_comm_backend(), timeout=pg_nccl_timeout) + + args = parse_args(arguments_cls) + logger.info(f"Process rank: {args.train.global_rank}, world size: {args.train.world_size}") + logger.info_rank0(json.dumps(asdict(args), indent=2)) + get_torch_device().set_device(f"{get_device_type()}:{args.train.local_rank}") + helper.set_seed(args.train.seed, args.train.enable_full_determinism) + helper.enable_high_precision_for_bf16() + if args.train.local_rank == 0: + helper.enable_third_party_logging() + + if args.train.global_rank == 0: + save_args(args, args.train.checkpoint.output_dir) + + set_checkpoint_debug_enabled(args.train.gradient_checkpointing.debug) + + Checkpointer = build_checkpointer( + dist_backend=args.train.accelerator.fsdp_config.fsdp_mode, + ckpt_manager=args.train.checkpoint.manager, + ) + + init_parallel_state( + dp_size=args.train.accelerator.dp_size, + dp_replicate_size=args.train.accelerator.dp_replicate_size, + dp_shard_size=args.train.accelerator.dp_shard_size, + tp_size=args.train.accelerator.tp_size, + pp_size=args.train.accelerator.pp_size, + cp_size=args.train.accelerator.cp_size, + extra_parallel_sizes=args.train.accelerator.extra_parallel_sizes, + extra_parallel_placement_innermost=args.train.accelerator.extra_parallel_placement_innermost, + extra_parallel_names=args.train.accelerator.extra_parallel_names, + ulysses_size=args.train.accelerator.ulysses_size, + dp_mode=args.train.accelerator.fsdp_config.fsdp_mode, + ) + + logger.info_rank0("Prepare data") + tokenizer = build_tokenizer(args.model.tokenizer_path) + transform = _build_transform(args, tokenizer) + train_dataset = _build_train_dataset(args, transform) + dataset_length = None if not hasattr(train_dataset, "__len__") else len(train_dataset) + if args.data.datasets_type in ("mapping", "local") and dataset_length is not None: + dataset_length = dataset_length / args.train.accelerator.dp_size + args.compute_train_steps(dataset_length) + + train_dataloader = build_dataloader( + dataloader_type=args.data.dataloader.type, + dataset=train_dataset, + micro_batch_size=args.train.micro_batch_size, + global_batch_size=args.train.global_batch_size, + dataloader_batch_size=args.train.dataloader_batch_size, + max_seq_len=args.data.max_seq_len, + train_steps=args.train_steps, + dyn_bsz=args.train.dyn_bsz, + dyn_bsz_runtime=args.train.dyn_bsz_runtime, + dyn_bsz_count_mode=args.train.dyn_bsz_count_mode, + dyn_bsz_physical_overflow_ratio=args.train.dyn_bsz_physical_overflow_ratio, + dyn_bsz_buffer_size=args.data.dyn_bsz_buffer_size, + bsz_warmup_ratio=args.train.bsz_warmup_ratio, + bsz_warmup_init_mbtoken=args.train.bsz_warmup_init_mbtoken, + num_workers=args.data.dataloader.num_workers, + worker_num_threads=args.data.dataloader.worker_num_threads, + drop_last=args.data.dataloader.drop_last, + pin_memory=args.data.dataloader.pin_memory, + prefetch_factor=args.data.dataloader.prefetch_factor, + seed=args.train.seed, + collate_fn_kwargs={"pad_to_length": args.train.pad_to_length}, + save_steps=args.train.checkpoint.save_steps, + ) + + logger.info_rank0("Prepare model") + model = build_foundation_model( + config_path=args.model.config_path, + weights_path=args.model.model_path, + torch_dtype="float32" if args.train.accelerator.fsdp_config.mixed_precision.enable else "bfloat16", + init_device=args.train.init_device, + ops_implementation=args.model.ops_implementation, + ) + model_config = model.config + helper.print_device_mem_info("VRAM usage after building model") + + get_optimizer_pre_hook = getattr(model, "get_optimizer_pre_hook", None) + basic_modules = list(set(getattr(model, "_no_split_modules", None) or []) | set(args.model.basic_modules)) + model = build_parallelize_model( + model, + init_device=args.train.init_device, + weights_path=args.model.model_path, + enable_reshard_after_forward=args.train.accelerator.fsdp_config.reshard_after_forward, + mixed_precision=args.train.accelerator.fsdp_config.mixed_precision, + enable_gradient_checkpointing=args.train.gradient_checkpointing.enable, + basic_modules=basic_modules, + enable_reentrant=args.train.gradient_checkpointing.enable_reentrant, + enable_forward_prefetch=args.train.accelerator.fsdp_config.forward_prefetch, + ) + + optimizer = build_optimizer( + model, + lr=args.train.optimizer.lr, + betas=(args.train.beta1, args.train.beta2), + weight_decay=args.train.optimizer.weight_decay, + fused=True, + optimizer_type=args.train.optimizer.type, + no_decay_modules=args.train.optimizer.no_decay_modules, + no_decay_params=args.train.optimizer.no_decay_params, + ) + if get_optimizer_pre_hook is not None: + optimizer_pre_hook = get_optimizer_pre_hook(model, model_config, args.train.accelerator.fsdp_config.fsdp_mode) + optimizer.register_step_pre_hook(optimizer_pre_hook) + + lr_scheduler = build_lr_scheduler( + optimizer, + train_steps=args.train_steps * args.train.num_train_epochs, + lr=args.train.optimizer.lr, + lr_min=args.train.optimizer.lr_min, + lr_decay_style=args.train.optimizer.lr_decay_style, + lr_decay_ratio=args.train.optimizer.lr_decay_ratio, + lr_warmup_ratio=args.train.optimizer.lr_warmup_ratio, + lr_start=args.train.optimizer.lr_start, + ) + + model_assets = None + if args.train.global_rank == 0: + if args.train.wandb.enable: + wandb.init( + project=args.train.wandb.project, + name=args.train.wandb.name, + id=args.train.wandb.id, + resume="allow" if args.train.wandb.id else None, + settings=wandb.Settings(console="off"), + config={**vars(args.model), **vars(args.data), **vars(args.train)}, + ) + + model_assets = [model_config, tokenizer] + save_model_assets(args.train.checkpoint.model_assets_dir, model_assets) + + if args.train.profile.this_rank: + profiler = helper.create_profiler( + start_step=args.train.profile.start_step, + end_step=args.train.profile.end_step, + trace_dir=args.train.profile.trace_dir, + record_shapes=args.train.profile.record_shapes, + profile_memory=args.train.profile.profile_memory, + with_stack=args.train.profile.with_stack, + with_modules=args.train.profile.with_modules, + global_rank=args.train.global_rank, + ) + profiler.start() + + start_epoch, start_step, global_step = 0, 0, 0 + save_checkpoint_path = None + environ_meter = helper.EnvironMeter( + config=model_config, + global_batch_size=args.train.global_batch_size, + empty_cache_steps=args.train.empty_cache_steps, + enable_multisource=args.data.enable_multisource, + dataloader=train_dataloader, + data_path=args.data.train_path, + ) + + if args.train.checkpoint.load_path: + state = {"model": model, "optimizer": optimizer, "extra_state": {}} + Checkpointer.load(args.train.checkpoint.load_path, state) + global_step = state["extra_state"]["global_step"] + start_epoch = global_step // args.train_steps + start_step = global_step % args.train_steps + lr_scheduler.load_state_dict(state["extra_state"]["lr_scheduler"]) + train_dataloader.load_state_dict(state["extra_state"]["train_dataloader"]) + environ_meter.load_state_dict(state["extra_state"]["environ_meter"]) + torch.set_rng_state(state["extra_state"]["torch_rng_state"]) + if start_step == 0: + iter(train_dataloader) + + dist.barrier() + logger.info_rank0(f"Load distributed checkpoint from {args.train.checkpoint.load_path} successfully!") + + block_diffusion_attn_mask = _build_block_diffusion_mask(args) + helper.empty_cache() + model_fwd_context, model_bwd_context = build_activation_offloading_context( + args.train.accelerator.offload_config.enable_activation, + args.train.gradient_checkpointing.enable, + args.train.accelerator.offload_config.activation_gpu_limit, + ) + model.train() + logger.info( + f"rank{args.train.local_rank} Start training, train_steps: {args.train_steps}, " + f"epochs: {args.train.num_train_epochs}" + ) + for epoch in range(start_epoch, args.train.num_train_epochs): + if hasattr(train_dataloader, "set_epoch"): + train_dataloader.set_epoch(epoch) + + data_loader_tqdm = trange( + args.train_steps, + desc=f"Epoch {epoch + 1}/{args.train.num_train_epochs}", + total=args.train_steps, + initial=start_step, + disable=args.train.local_rank != 0, + ) + data_iterator = iter(train_dataloader) + for _ in range(start_step, args.train_steps): + global_step += 1 + + try: + micro_batches: List[Dict[str, Any]] = next(data_iterator) + except StopIteration: + logger.info(f"epoch:{epoch} Dataloader finished with drop_last {args.data.dataloader.drop_last}") + break + + if global_step == 1: + helper.print_example(example=micro_batches[0], rank=args.train.local_rank) + + total_loss = 0.0 + total_consistency_loss = 0.0 + total_confidence_loss = 0.0 + synchronize() + start_time = time.time() + num_micro_steps = len(micro_batches) + + for micro_step, micro_batch in enumerate(micro_batches): + if ( + args.train.accelerator.fsdp_config.fsdp_mode == "fsdp2" + and not args.train.accelerator.fsdp_config.reshard_after_backward + and num_micro_steps > 1 + ): + if micro_step == 0: + model.set_reshard_after_backward(False) + elif micro_step == num_micro_steps - 1: + model.set_reshard_after_backward(True) + + environ_meter.add(micro_batch) + if args.data.enable_multisource: + micro_batch.pop("ds_idx", None) + micro_batch.pop("cur_token_num", None) + micro_batch.pop("source_name", None) + + micro_batch, noisy_seq_len = _prepare_micro_batch(args, micro_batch, block_diffusion_attn_mask) + labels = micro_batch.pop("labels", None) + + with model_fwd_context: + logits = model(**micro_batch, use_cache=False, output_router_logits=False).logits + noisy_logits = logits[:, :noisy_seq_len].contiguous() if args.train.block_diffusion_mode else logits + loss, consistency_loss, confidence_loss = _compute_llada2_loss( + args, + noisy_logits=noisy_logits, + labels=labels, + num_micro_steps=num_micro_steps, + ) + + with model_bwd_context: + loss.backward() + + total_loss += loss.item() + total_consistency_loss += consistency_loss.item() / num_micro_steps + total_confidence_loss += confidence_loss.item() / num_micro_steps + del micro_batch + + grad_norm = veomni_clip_grad_norm(model, args.train.optimizer.max_grad_norm) + + optimizer.step() + lr_scheduler.step() + optimizer.zero_grad() + + if args.train.confidence_beta > 0: + total_loss, total_consistency_loss, total_confidence_loss, grad_norm = all_reduce( + (total_loss, total_consistency_loss, total_confidence_loss, grad_norm), + group=get_parallel_state().fsdp_group, + ) + else: + total_loss, grad_norm = all_reduce((total_loss, grad_norm), group=get_parallel_state().fsdp_group) + synchronize() + + delta_time = time.time() - start_time + lr = max(lr_scheduler.get_last_lr()) + train_metrics = environ_meter.step(delta_time, global_step=global_step) + + postfix = f"loss: {total_loss:.4f}, grad_norm: {grad_norm:.4f}, lr: {lr:.2e}" + if args.train.confidence_beta > 0: + postfix = ( + f"loss: {total_loss:.4f}, cons: {total_consistency_loss:.4f}, " + f"conf: {total_confidence_loss:.4f}, grad_norm: {grad_norm:.4f}, lr: {lr:.2e}" + ) + data_loader_tqdm.set_postfix_str(postfix, refresh=False) + data_loader_tqdm.update() + + if args.train.global_rank == 0 and args.train.wandb.enable: + train_metrics.update( + { + "training/loss": total_loss, + "training/grad_norm": grad_norm, + "training/lr": lr, + } + ) + if args.train.confidence_beta > 0: + train_metrics.update( + { + "training/cons_loss": total_consistency_loss, + "training/conf_loss": total_confidence_loss, + } + ) + wandb.log(train_metrics, step=global_step) + + if args.train.profile.this_rank and global_step <= args.train.profile.end_step: + profiler.step() + if global_step == args.train.profile.end_step: + profiler.stop() + + if args.train.checkpoint.save_steps and global_step % args.train.checkpoint.save_steps == 0: + helper.empty_cache() + save_checkpoint_path = os.path.join(args.train.checkpoint.save_path, f"global_step_{global_step}") + state = { + "model": model, + "optimizer": optimizer, + "extra_state": { + "global_step": global_step, + "lr_scheduler": lr_scheduler.state_dict(), + "train_dataloader": train_dataloader.state_dict(), + "environ_meter": environ_meter.state_dict(), + "torch_rng_state": torch.get_rng_state(), + }, + } + Checkpointer.save(args.train.checkpoint.save_path, state, global_steps=global_step) + + dist.barrier() + logger.info_rank0(f"Distributed checkpoint saved at {save_checkpoint_path} successfully!") + + data_loader_tqdm.close() + start_step = 0 + helper.print_device_mem_info(f"VRAM usage after epoch {epoch + 1}") + if args.train.checkpoint.save_epochs and (epoch + 1) % args.train.checkpoint.save_epochs == 0: + helper.empty_cache() + save_checkpoint_path = os.path.join(args.train.checkpoint.save_path, f"global_step_{global_step}") + state = { + "model": model, + "optimizer": optimizer, + "extra_state": { + "global_step": global_step, + "lr_scheduler": lr_scheduler.state_dict(), + "train_dataloader": train_dataloader.state_dict(), + "environ_meter": environ_meter.state_dict(), + "torch_rng_state": torch.get_rng_state(), + }, + } + Checkpointer.save(args.train.checkpoint.save_path, state, global_steps=global_step) + dist.barrier() + logger.info_rank0(f"Distributed checkpoint saved at {save_checkpoint_path} successfully!") + + synchronize() + del optimizer, lr_scheduler + helper.empty_cache() + if args.train.checkpoint.save_hf_weights and save_checkpoint_path is not None: + hf_weights_path = os.path.join(save_checkpoint_path, "hf_ckpt") + save_hf_safetensor( + save_hf_safetensor_path=hf_weights_path, + ckpt_manager=args.train.checkpoint.manager, + model_assets=model_assets, + save_checkpoint_path=save_checkpoint_path, + is_rank_0=args.train.global_rank == 0, + model=model, + fqn_to_index_mapping=args.model.fqn_to_index_mapping, + ) + + dist.barrier() + dist.destroy_process_group() diff --git a/train.sh b/train.sh index 7c567b1..bc6fc9e 100644 --- a/train.sh +++ b/train.sh @@ -6,7 +6,16 @@ export TOKENIZERS_PARALLELISM=false export TORCH_NCCL_AVOID_RECORD_STREAMS=1 NNODES=${NNODES:=1} -NPROC_PER_NODE=${NPROC_PER_NODE:=$(nvidia-smi --list-gpus | wc -l)} +if [[ -z "${NPROC_PER_NODE:-}" ]]; then + if command -v nvidia-smi >/dev/null 2>&1; then + NPROC_PER_NODE=$(nvidia-smi --list-gpus | wc -l | tr -d ' ') + elif command -v npu-smi >/dev/null 2>&1; then + NPROC_PER_NODE=$(npu-smi info 2>/dev/null | awk '/^[[:space:]]*[0-9]+[[:space:]]+[0-9]+/ {print $1}' | sort -u | wc -l | tr -d ' ') + [[ "$NPROC_PER_NODE" == "0" ]] && NPROC_PER_NODE=1 + else + NPROC_PER_NODE=1 + fi +fi NODE_RANK=${NODE_RANK:=0} MASTER_ADDR=${MASTER_ADDR:=0.0.0.0} MASTER_PORT=${MASTER_PORT:=12345}