diff --git a/tests/ray/test_pack.py b/tests/ray/test_pack.py
new file mode 100644
index 000000000..199ef5576
--- /dev/null
+++ b/tests/ray/test_pack.py
@@ -0,0 +1,166 @@
+import unittest
+
+import torch
+
+from xtuner.v1.data_proto.sequence_context import SequenceContext
+from xtuner.v1.rl.base.pack import RLDataPacker
+
+
+class TestDataBatchPacker(unittest.TestCase):
+    def setUp(self):
+        self.pack_max_length = 3072
+        self.split_size = 1024
+
+    def _create_dummy_item(self, length: int, val=1):
+        input_ids = torch.full((1, length), val, dtype=torch.long)
+        cu_seq_lens_q = torch.tensor([0, length], dtype=torch.int32)
+        cu_seq_lens_k = torch.tensor([0, length], dtype=torch.int32)
+        max_length_q = torch.tensor(length, dtype=torch.int32)
+        max_length_k = torch.tensor(length, dtype=torch.int32)
+        seq_ctx = SequenceContext(
+            input_ids=input_ids,
+            cu_seq_lens_q=cu_seq_lens_q,
+            cu_seq_lens_k=cu_seq_lens_k,
+            max_length_q=max_length_q,
+            max_length_k=max_length_k,
+            num_padding=0,
+            device="cpu",
+        )
+        return {
+            "seq_ctx": seq_ctx,
+            "shifted_labels": torch.full((1, length), val, dtype=torch.long),
+            "advantages": torch.full((1, length), float(val), dtype=torch.float),
+            "rollout_logprobs": torch.full((1, length), float(val), dtype=torch.float),
+        }
+
+    def _run_strategy_test(self, strategy, world_size, optimizer_steps, lengths, pack_max_length, expected_padding=None):
+        data_batches = [self._create_dummy_item(l, val=7) for l in lengths]
+        total_data_tokens = sum(lengths)
+
+        packer = RLDataPacker(
+            pack_max_length=pack_max_length,
+            world_size=world_size,
+            data_replicate_size=1,
+            optimizer_steps=optimizer_steps,
+            pack_strategy=strategy,
+        )
+
+        packed_res, padding_tokens = packer.pack(data_batches)
+
+        # Check balance: ideally, the difference between the per-rank token totals produced
+        # by the balance strategy should be smaller than the longest single sample.
+        if strategy == "balance":
+            rank_token_counts = []
+            for rank_data in packed_res:
+                rank_total_valid_tokens = 0
+                for step_data in rank_data:
+                    for pack in step_data:
+                        # Count the non-zero (i.e. non-padding) valid tokens.
+                        valid_tokens = (pack["seq_ctx"].input_ids != 0).sum().item()
+                        rank_total_valid_tokens += valid_tokens
+                rank_token_counts.append(rank_total_valid_tokens)
+
+            max_tokens = max(rank_token_counts)
+            min_tokens = min(rank_token_counts)
+            diff = max_tokens - min_tokens
+            max_sample_len = max(lengths) if lengths else 0
+            self.assertLessEqual(diff, max_sample_len,
+                                 f"Balance strategy failed: Token distribution is too skewed. "
+                                 f"Rank counts: {rank_token_counts}, Max diff: {diff}")
+
+        # For fixed inputs, verify the pack logic by comparing padding_tokens with the expected value.
+        if expected_padding is not None:
+            self.assertEqual(padding_tokens, expected_padding, f"Strategy {strategy} padding mismatch. Expected {expected_padding}, got {padding_tokens}")
+
+        all_packs = []
+        for rank_data in packed_res:
+            for step_data in rank_data:
+                for pack in step_data:
+                    self.assertEqual(pack["seq_ctx"].input_ids.numel(), pack_max_length, f"Strategy {strategy} pack length mismatch.")
+                    all_packs.append(pack)
+
+        # Verify that the total number of valid tokens is the same before and after packing.
+        total_capacity = len(all_packs) * pack_max_length
+        self.assertEqual(total_capacity, total_data_tokens + padding_tokens)
+
+        all_input_ids = torch.cat([p["seq_ctx"].input_ids for p in all_packs], dim=1)
+        valid_token_count = (all_input_ids != 0).sum().item()
+        all_labels = torch.cat([p["shifted_labels"] for p in all_packs], dim=1)
+        valid_label_count = (all_labels != -100).sum().item()
+        all_advantages = torch.cat([p["advantages"] for p in all_packs], dim=1)
+        valid_adv_count = (all_advantages != -100).sum().item()
+
+        self.assertEqual(valid_token_count, total_data_tokens)
+        self.assertEqual(valid_label_count, total_data_tokens)
+        self.assertEqual(valid_adv_count, total_data_tokens)
+
+    def test_variable_packs(self):
+        """Inputs with random token counts, dp=2, optimizer_steps=2.
+        - Native:
+          1. Preprocess so the sample count is divisible by dp_size; append a 1024-token padding sample so it can be packed together with the real samples:
+             [1500, 1000, 2800, 3000, 1500, 2000, 2100, 1000, 800] -> padded: [1500, 1000, 2800, 3000, 1500, 2000, 2100, 1000, 800, 1024]
+          2. Split across DP ranks:
+             rank0: [1500, 1000, 2800, 3000, 1500]
+             rank1: [2000, 2100, 1000, 800, 1024]
+          3. Split across optimizer steps:
+             rank0: [1500, 1000, 2800], [3000, 1500]
+             rank1: [2000, 2100, 1000], [800, 1024]
+          4. Pack and pad:
+             rank0: step0: [2500 -> 3072], [2800 -> 3072], step1: [3000 -> 3072], [1500 -> 3072]
+             rank1: step0: [2000 -> 3072], [2100 -> 3072], [1000 -> 3072], step1: [1824 -> 3072]
+          5. Align the number of packs across ranks:
+             rank0: step0: [2500 -> 3072], [2800 -> 3072], [0 -> 3072], step1: [3000 -> 3072], [1500 -> 3072]
+             rank1: step0: [2100 -> 3072], [2000 -> 3072], [1000 -> 3072], step1: [1824 -> 3072], [0 -> 3072]
+             padding_tokens: 1024 + 3072 - 2500 + 3072 - 2800 + 3072 + 3072 - 3000 + 3072 - 1500 + 3072 - 2100 + 3072 - 2000 + 3072 - 1000 + 3072 - 1824 + 3072 = 15020
+        - Balance:
+          1. Sort the raw inputs by length:
+             [1500, 1000, 2800, 3000, 1500, 2000, 2100, 1000, 800] -> [3000, 2800, 2100, 2000, 1500, 1500, 1000, 1000, 800]
+          2. Distribute N samples of similar length across the N ranks; every group of N samples forms one optimizer step's data for the N ranks:
+             rank0: [3000, 1500, 800], [2100, 1000]
+             rank1: [2800, 1500], [2000, 1000]
+          3. Pack and pad:
+             rank0: step0: [3000 -> 3072], [2300 -> 3072], step1: [2100 -> 3072], [1000 -> 3072]
+             rank1: step0: [2800 -> 3072], [1500 -> 3072], step1: [3000 -> 3072], [0 -> 3072]
+          4. Align the number of packs across ranks:
+             skip
+             padding_tokens: 3072 - 3000 + 3072 - 2300 + 3072 - 2100 + 3072 - 1000 + 3072 - 2800 + 3072 - 1500 + 3072 - 3000 + 3072 = 8876
+        - Greedy: maximize the fill rate of each pack.
+          1. Pack and pad:
+             Pack 1: [1500, 1000] -> [2500 -> 3072]
+             Pack 2: [2800] -> [2800 -> 3072]
+             Pack 3: [3000] -> [3000 -> 3072]
+             Pack 4: [1500] -> [1500 -> 3072]
+             Pack 5: [2000] -> [2000 -> 3072]
+             Pack 6: [2100] -> [2100 -> 3072]
+             Pack 7: [1000, 800] -> [1800 -> 3072]
+             Pack 8: [] -> [0 -> 3072] (padding)
+          2. Split across DP ranks:
+             rank0: [Pack 1, Pack 2, Pack 3, Pack 4]
+             rank1: [Pack 5, Pack 6, Pack 7, Pack 8]
+          3. Split across optimizer steps:
+             rank0: step0: [Pack 1, Pack 2], step1: [Pack 3, Pack 4]
+             rank1: step0: [Pack 5, Pack 6], step1: [Pack 7, Pack 8]
+          4. Align the number of packs across ranks:
+             skip
+             padding_tokens: 3072 - 2500 + 3072 - 2800 + 3072 - 3000 + 3072 - 1500 + 3072 - 2000 + 3072 - 2100 + 3072 - 1800 + 3072 = 8876
+        """
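+        # How to read the packer output (a sketch mirroring the checks in
+        # _run_strategy_test): pack() returns (packed_res, padding_tokens), where
+        # packed_res is indexed as packed_res[dp_rank][optimizer_step][pack_idx],
+        # so e.g. packed_res[0][0][0]["seq_ctx"].input_ids.numel() == pack_max_length.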
+        lengths = [1500, 1000, 2800, 3000, 1500, 2000, 2100, 1000, 800]
+        self._run_strategy_test("native", 2, 2, lengths, self.pack_max_length, 15020)
+        self._run_strategy_test("balance", 2, 2, lengths, self.pack_max_length, 8876)
+        self._run_strategy_test("greedy", 2, 2, lengths, self.pack_max_length, 8876)
+
+    def test_imbalance_dp_size(self):
+        lengths = [500]
+        for strat in ["native", "balance", "greedy"]:
+            self._run_strategy_test(strat, 2, 1, lengths, self.pack_max_length, 5644)
+
+    def test_imbalanced_steps(self):
+        lengths = [100, 200, 2500, 3000, 50, 400, 1000, 1500]
+        self._run_strategy_test("native", 2, 4, lengths, self.pack_max_length, 15826)
+        self._run_strategy_test("balance", 2, 4, lengths, self.pack_max_length, 15826)
+        self._run_strategy_test("greedy", 2, 4, lengths, self.pack_max_length, 3538)
+
+    def test_random_lengths(self):
+        import random
+        lengths = [random.randint(1, 32768) for _ in range(1024)]
+        for strat in ["native", "balance", "greedy"]:
+            self._run_strategy_test(strat, 8, 16, lengths, 32768)
+
+
+if __name__ == "__main__":
+    unittest.main()
\ No newline at end of file
diff --git a/tests/ray/test_rl_train_with_sft.py b/tests/ray/test_rl_train_with_sft.py
index be7dc9381..6f2ec708d 100644
--- a/tests/ray/test_rl_train_with_sft.py
+++ b/tests/ray/test_rl_train_with_sft.py
@@ -67,7 +67,7 @@ def setUp(self):
                 dict(
                     seq_ctx=SequenceContext.from_input_ids((input_ids,), device="cpu"),
                     shifted_labels=shifted_labels,
-                    advantage=advantages[i].item(),
+                    advantages=advantages[i].item(),
                 )
             )
         self.data_batches = data_batches
@@ -156,10 +156,10 @@ def test_rl_train_with_sft(self):
         ray.get(train_controller.fit.remote(self.data_batches, pack_max_length=1024, rollout_idx=0))
         ray.get(train_controller.save.remote(os.path.join(self.temp_dir, "save_test"), no_save_optimizer=True))
 
-        log_infos = ray.get(train_controller.fit.remote(self.data_batches, pack_max_length=1024, rollout_idx=1))
+        train_log_infos = ray.get(train_controller.fit.remote(self.data_batches, pack_max_length=1024, rollout_idx=1))
         efficient_attn_ratio_list = []
-        for log_info in log_infos:
-            efficient_attn_ratio_list.append(log_info['sft_train_metrics']['efficient_attn_ratio'])
+        for log_info in train_log_infos['worker_log_infos']:
+            efficient_attn_ratio_list.append(log_info["sft_train_metrics"]['efficient_attn_ratio'])
         assert all([efficient_attn_ratio > 0 for efficient_attn_ratio in efficient_attn_ratio_list])
 
         ray.kill(train_controller)
@@ -170,10 +170,10 @@ def test_rl_train_with_sft(self):
         )
         ray.get(train_controller.resume.remote(load_checkpoint_cfg))
 
-        log_infos = ray.get(train_controller.fit.remote(self.data_batches, pack_max_length=1024, rollout_idx=1))
+        train_log_infos = ray.get(train_controller.fit.remote(self.data_batches, pack_max_length=1024, rollout_idx=1))
         new_efficient_attn_ratio_list = []
-        for log_info in log_infos:
-            new_efficient_attn_ratio_list.append(log_info['sft_train_metrics']['efficient_attn_ratio'])
+        for log_info in train_log_infos['worker_log_infos']:
+            new_efficient_attn_ratio_list.append(log_info["sft_train_metrics"]['efficient_attn_ratio'])
 
         efficient_attn_ratio_list.sort()
         new_efficient_attn_ratio_list.sort()
diff --git a/xtuner/v1/rl/base/__init__.py b/xtuner/v1/rl/base/__init__.py
index d75603b57..e57889f7a 100644
--- a/xtuner/v1/rl/base/__init__.py
+++ b/xtuner/v1/rl/base/__init__.py
@@ -1,6 +1,13 @@
-from .controller import TrainingController, TrainingControllerProxy
+from .controller import TrainingController, TrainingControllerProxy, TrainingLogInfo
 from .loss import BaseRLLossConfig, RLLossContextInputItem
-from .worker import TrainingWorker, TrainingWorkerClass, TrainingWorkerProxy, WorkerConfig, WorkerLogItem
+from .worker import (
+    TrainingWorker,
+    TrainingWorkerClass,
+    TrainingWorkerProxy,
+    WorkerConfig,
+    WorkerInputItem,
+    WorkerLogItem,
+)
 
 
 __all__ = [
@@ -13,4 +20,6 @@
     "BaseRLLossConfig",
     "RLLossContextInputItem",
    "WorkerLogItem",
+    "WorkerInputItem",
+    "TrainingLogInfo",
 ]
diff --git a/xtuner/v1/rl/base/controller.py b/xtuner/v1/rl/base/controller.py
index b500b53e4..985a87860 100644
--- a/xtuner/v1/rl/base/controller.py
+++ b/xtuner/v1/rl/base/controller.py
@@ -1,267 +1,82 @@
-import math
 import os
-from typing import Literal, TypedDict
+import time
+from pathlib import Path
+from typing import Literal
 
 import ray
 import torch
 from ray.actor import ActorProxy
+from typing_extensions import TypedDict
 
-from xtuner.v1.data_proto.sequence_context import SequenceContext
-from xtuner.v1.model.compose.base import BaseComposeConfig
 from xtuner.v1.train.trainer import LoadCheckpointConfig
-from xtuner.v1.utils import ray_method
+from xtuner.v1.utils import get_logger, ray_method
 
-from .worker import TrainingWorker, WorkerLogItem
+from .pack import RLDataPacker
+from .worker import TrainingWorker, WorkerInputItem, WorkerLogItem
 
 
 TRAIN_RAY_GET_TIMEOUT = os.getenv("XTUNER_TRAIN_RAY_GET_TIMEOUT", 5 * 3600)  # default 5 hours
 
 
-class ColateItem(TypedDict):
-    seq_ctx: SequenceContext
-    shifted_labels: torch.Tensor
-    advantage: float
-    rollout_logprobs: torch.Tensor | None
+class TrainingLogInfo(TypedDict):
+    worker_log_infos: list[WorkerLogItem]
+    padding_tokens: int
+    pack_time: float
+    train_time: float
 
 
 class RawTrainingController:
     def __init__(self, workers: list[TrainingWorker]) -> None:
         self.workers = workers
-
-    # TODO(hha): 这个逻辑不够通用,应该复用 sft 函数,从而支持 expand soft pack
-    def _get_pack_infos(self, dataset, num_tokens, target, random=None):
-        inds = list(range(len(dataset)))
-        if random is not None:
-            random.shuffle(inds)
-
-        item_buffer = []
-        length_buffer = []
-        longest = 0
-
-        pack_infos = []
-        for shfl_i in inds:
-            if num_tokens[shfl_i] + sum(length_buffer) <= target:
-                item_buffer.append(shfl_i)
-                length_buffer.append(num_tokens[shfl_i])
-                longest = max(longest, num_tokens[shfl_i])
-            else:
-                if len(item_buffer) > 0:
-                    info = {
-                        "indices": item_buffer,
-                        "longest": int(longest),
-                    }
-                    pack_infos.append(info)
-
-                item_buffer = [shfl_i]
-                length_buffer = [num_tokens[shfl_i]]
-                longest = num_tokens[shfl_i]
-
-        if len(item_buffer) > 0:
-            info = {
-                "indices": item_buffer,
-                "longest": int(longest),
-            }
-
-            pack_infos.append(info)
-
-        return pack_infos
-
-    # TODO(hha): 这个逻辑不够通用,和模型绑定了
-    def _packing(self, data_batches, pack_max_length, language_cfg):
-        pack_infos = self._get_pack_infos(
-            data_batches,
-            [data["seq_ctx"].input_ids.numel() for data in data_batches],
-            pack_max_length,
+        refs = [
+            self.workers[0].get_model_cfg.remote(),
+            self.workers[0].get_worker_cfg.remote(),
+            self.workers[0].get_data_replicate_size.remote(),
+        ]
+        self.model_cfg, self.worker_cfg, self.data_replicate_size = ray.get(refs)
+        dp_ranks_handle = [worker.get_dp_rank.remote() for worker in self.workers]
+        self.worker_dp_ranks = ray.get(dp_ranks_handle)
+        self.pack_max_length = self.worker_cfg.pack_max_length
+        self.pack_strategy = self.worker_cfg.pack_strategy
+        self.data_packer = RLDataPacker(
+            pack_max_length=self.pack_max_length,
+            world_size=len(self.workers),
+            data_replicate_size=self.data_replicate_size,
+            optimizer_steps=self.worker_cfg.optimizer_steps,
+            pack_strategy=self.pack_strategy,
+            worker_log_dir=self.worker_cfg.log_dir,
         )
-        packed_data_batches = []
-
-        is_qwen3_vl = False
-        if len(data_batches[0]["seq_ctx"].position_ids.shape) == 3:
-            is_qwen3_vl = True
-
-        has_rollout_routed_experts = False
-        if data_batches[0]["seq_ctx"].rollout_routed_experts is not None:
-            assert language_cfg is not None
-            has_rollout_routed_experts = True
-            n_routed_experts = language_cfg.n_routed_experts
-
-        for pack_info in pack_infos:
-            indices = pack_info["indices"]
-            total_len = sum([data_batches[i]["seq_ctx"].input_ids.shape[1] for i in indices])
-            pad_len = pack_max_length - total_len
-            seq_ctx_list = [data_batches[i]["seq_ctx"] for i in indices]
-            label_list = [data_batches[i]["shifted_labels"] for i in indices]
-            advantage_list = [data_batches[i]["advantage"] for i in indices]
-
-            rollout_logprobs_list = None
-            if "rollout_logprobs" in data_batches[0] and data_batches[0]["rollout_logprobs"] is not None:
-                rollout_logprobs_list = [data_batches[i]["rollout_logprobs"] for i in indices]
-
-            if pad_len > 0:
-                # Reduce the attn calculation time by using multiple short sequence packs
-                pad_tokens = tuple(
-                    torch.zeros(1, 1024, dtype=data_batches[0]["seq_ctx"].input_ids.dtype, device="cpu")
-                    for _ in range(pad_len // 1024)
-                )
-                if pad_len % 1024 > 0:
-                    pad_tokens = pad_tokens + (
-                        torch.zeros(1, pad_len % 1024, dtype=data_batches[0]["seq_ctx"].input_ids.dtype, device="cpu"),
-                    )
-                pad_seq_ctx = SequenceContext.from_input_ids(pad_tokens, device="cpu")
-                pad_seq_ctx.num_padding = pad_len
-                pad_labels = torch.full(
-                    (1, pad_len),
-                    -100,
-                    dtype=data_batches[0]["shifted_labels"].dtype,
-                    device=data_batches[0]["shifted_labels"].device,
-                )
-                if is_qwen3_vl:
-                    _position_ids_list = []
-                    for pad_token in pad_tokens:
-                        _position_ids = torch.arange(pad_token.size(-1)).view(1, 1, -1).expand(3, 1, -1)
-                        _position_ids_list.append(_position_ids)
-                    pad_seq_ctx.position_ids = torch.cat(_position_ids_list, dim=-1)
-
-                if has_rollout_routed_experts:
-                    pad_rand_index = torch.randint(low=0, high=n_routed_experts, size=(pad_len, 1, 1))
-                    pad_seq_ctx.rollout_routed_experts = pad_rand_index
-
-                seq_ctx_list.append(pad_seq_ctx)
-                label_list.append(pad_labels)
-                advantage_list.extend(
-                    [-100] * math.ceil(pad_len / 1024)
-                )  # can be any number, pad tokens are excluded from the calculation of the loss function.
-
-                if rollout_logprobs_list is not None:
-                    pad_rollout_logprobs = torch.zeros(
-                        1,
-                        pad_len,
-                        dtype=data_batches[0]["rollout_logprobs"].dtype,
-                        device=data_batches[0]["shifted_labels"].device,
-                    )
-                    rollout_logprobs_list.append(pad_rollout_logprobs)
-
-            seq_ctx = SequenceContext.cat(seq_ctx_list)
-            shifted_labels = torch.cat(label_list, dim=1)  # (1, max_len)
-            advantages = torch.tensor(advantage_list).float().unsqueeze(0)  # (1, num_samples)
-            cu_seq_lens_q = seq_ctx.cu_seq_lens_q
-            num_tokens = cu_seq_lens_q[1:] - cu_seq_lens_q[:-1]
-            advantages = torch.repeat_interleave(advantages, num_tokens, dim=1)  # (1, max_len)
-
-            rollout_logprobs = None
-            if rollout_logprobs_list is not None:
-                rollout_logprobs = torch.cat(rollout_logprobs_list, dim=1)  # (1, max_len)
-
-            packed_data_batches.append(
-                {
-                    "seq_ctx": seq_ctx,
-                    "shifted_labels": shifted_labels,
-                    "advantages": advantages,
-                    "rollout_logprobs": rollout_logprobs,
-                }
-            )
-        return packed_data_batches
-
-    def _grouped_by_max_length(self, packed_data_batches):
-        # sort 过后可能第一个 batch 会有很多 pad tokens,因为最后一个 pack 可能只有少量真实数据。
-        # 比如组成了 16 个 pack,第 16 个 pack 可能只有几条真实数据,剩下的都是 pad tokens。
-        # 排序后这条 pack 会被放在最前面,导致 rank0 的第一个 step 消耗的有效 token 数往往少于其他 rank,是正常现象。
-        return sorted(packed_data_batches, key=lambda x: x["seq_ctx"].max_length_q, reverse=True)
+        log_dir = self.worker_cfg.log_dir
+        self.log_dir = None
+        if log_dir is not None:
+            self.log_dir = Path(log_dir) if isinstance(log_dir, str) else log_dir
+            self.logger = get_logger(log_dir=self.log_dir, tag="TrainingController")
+        else:
+            self.logger = get_logger()
 
     @ray_method
-    def fit(self, data_batches: list[ColateItem], pack_max_length: int, rollout_idx: int) -> list[WorkerLogItem]:
-        has_rollout_routed_experts = False
-        language_cfg = None
-        if data_batches[0]["seq_ctx"].rollout_routed_experts is not None:
-            model_cfg = ray.get(self.workers[0].get_model_cfg.remote())  # type: ignore[attr-defined]
-            has_rollout_routed_experts = True
-            language_cfg = model_cfg
-            if isinstance(model_cfg, BaseComposeConfig):
-                language_cfg = model_cfg.text_config
-
-        packed_data_batches = self._packing(data_batches, pack_max_length, language_cfg)
-        # packed_data_batches = self._grouped_by_max_length(packed_data_batches)
-
-        # TODO(hha): 这个逻辑不够通用,和模型绑定了
-        is_qwen3_vl = False
-        if len(packed_data_batches[0]["seq_ctx"].position_ids.shape) == 3:
-            is_qwen3_vl = True
-
-        # todo: support round up
-        num_packed_data_batches = len(packed_data_batches)
-        data_replicate_size = ray.get(self.workers[0].get_data_replicate_size.remote())  # type: ignore[attr-defined]
-        dp_size = len(self.workers) // data_replicate_size
-        pad_num = math.ceil(num_packed_data_batches / dp_size) * dp_size - num_packed_data_batches
-        if pad_num > 0:
-            # Reduce the attn calculation time by using multiple short sequence packs
-            assert data_batches[0]["seq_ctx"].input_ids is not None
-            pad_tokens = tuple(
-                torch.zeros(1, 1024, dtype=data_batches[0]["seq_ctx"].input_ids.dtype, device="cpu")
-                for _ in range(pack_max_length // 1024)
-            )
-            if pack_max_length % 1024 > 0:
-                assert data_batches[0]["seq_ctx"].input_ids is not None
-                pad_tokens = pad_tokens + (
-                    torch.zeros(
-                        1, pack_max_length % 1024, dtype=data_batches[0]["seq_ctx"].input_ids.dtype, device="cpu"
-                    ),
-                )
-            pad_seq_ctx = SequenceContext.from_input_ids(pad_tokens, device="cpu")  # type: ignore
-            pad_seq_ctx.num_padding = pack_max_length
-            if is_qwen3_vl:
-                _position_ids_list = []
-                for pad_token in pad_tokens:
-                    _position_ids = torch.arange(pad_token.size(-1)).view(1, 1, -1).expand(3, 1, -1)
-                    _position_ids_list.append(_position_ids)
-                pad_seq_ctx.position_ids = torch.cat(_position_ids_list, dim=-1)  # type: ignore
-
-            pad_shifted_labels = torch.full(
-                (1, pack_max_length),
-                -100,
-                dtype=packed_data_batches[0]["shifted_labels"].dtype,
-                device="cpu",
-            )
-            pad_advantages = torch.full(
-                (1, pack_max_length),
-                -100,
-                dtype=packed_data_batches[0]["advantages"].dtype,
-                device="cpu",
-            )
-
-            if has_rollout_routed_experts:
-                pad_rand_index = torch.randint(
-                    low=0,
-                    high=1,
-                    size=(1, 1, 1),  # add dummy data, true data will be initialized in train worker.fit
-                )
-                pad_seq_ctx.rollout_routed_experts = pad_rand_index
-
-            pad_rollout_logprobs = None
-            if "rollout_logprobs" in packed_data_batches[0] and packed_data_batches[0]["rollout_logprobs"] is not None:
-                pad_rollout_logprobs = torch.zeros(
-                    1, pack_max_length, dtype=packed_data_batches[0]["rollout_logprobs"].dtype, device="cpu"
-                )
-            pad_data = {
-                "seq_ctx": pad_seq_ctx,
-                "shifted_labels": pad_shifted_labels,
-                "advantages": pad_advantages,
-                "rollout_logprobs": pad_rollout_logprobs,
-            }
-            pad_data_samples = [pad_data for _ in range(pad_num)]
-            packed_data_batches = packed_data_batches + pad_data_samples
-
-        print(f"len(packed_data_batches): {len(packed_data_batches)}")
-
+    def fit(self, data_batches: list[WorkerInputItem], pack_max_length: int, rollout_idx: int) -> TrainingLogInfo:
+        start_time = time.perf_counter()
+        packed_data_batches, padding_tokens_num = self.data_packer.pack(data_batches)
+        pack_end_time = time.perf_counter()
         handles = []
         for worker_idx, worker in enumerate(self.workers):
+            dp_rank = self.worker_dp_ranks[worker_idx]
             handles.append(
                 worker.fit.remote(  # type: ignore[attr-defined]
-                    data_batches=packed_data_batches[(worker_idx // data_replicate_size) :: dp_size],
+                    data_batches=packed_data_batches[dp_rank],
                     rollout_idx=rollout_idx,
                 )
             )
-        log_infos = ray.get(handles, timeout=TRAIN_RAY_GET_TIMEOUT)
-        return log_infos
+        worker_log_infos = ray.get(handles, timeout=TRAIN_RAY_GET_TIMEOUT)
+        # Take the end timestamp after ray.get so train_time covers the remote
+        # training itself, not just the dispatch of the remote calls.
+        train_end_time = time.perf_counter()
+        train_log_info: TrainingLogInfo = {
+            "worker_log_infos": worker_log_infos,
+            "pack_time": pack_end_time - start_time,
+            "train_time": train_end_time - pack_end_time,
+            "padding_tokens": padding_tokens_num,
+        }
+        return train_log_info
 
     @ray_method
     def offload(self, target: Literal["model", "optimizer", "all"] = "all"):
diff --git a/xtuner/v1/rl/base/pack.py b/xtuner/v1/rl/base/pack.py
new file mode 100644
index 000000000..83030ebe3
--- /dev/null
+++ b/xtuner/v1/rl/base/pack.py
@@ -0,0 +1,391 @@
+import math
+import random
+from pathlib import Path
+from typing import cast
+
+import numpy as np
+import torch
+
+from xtuner.v1.data_proto.sequence_context import SequenceContext
+from xtuner.v1.datasets.sampler import get_length_grouped_indices
+from xtuner.v1.model.base import TransformerConfig
+from xtuner.v1.model.compose.base import BaseComposeConfig
+from xtuner.v1.rl.base.worker import WorkerInputItem
+from xtuner.v1.utils import get_logger
+
+
+# TODO: use expand_soft pack strategy in sft for greedy pack
+def get_soft_pack_infos(data_batches, num_tokens, target):
+    pack_infos = []
+    current_indices = []
+    current_len = 0
+    current_max = 0
+
+    for i, token_len in enumerate(num_tokens):
+        if current_len + token_len <= target:
+            current_indices.append(i)
+            current_len += token_len
+            current_max = max(current_max, token_len)
+        else:
+            if current_indices:
+                pack_infos.append(
+                    {
+                        "indices": current_indices,
+                        "longest": int(current_max),
+                    }
+                )
+            current_indices = [i]
+            current_len = token_len
+            current_max = token_len
+
+    if current_indices:
+        pack_infos.append(
+            {
+                "indices": current_indices,
+                "longest": int(current_max),
+            }
+        )
+    return pack_infos
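+
+
+# A worked illustration of get_soft_pack_infos (lengths here are made up):
+# get_soft_pack_infos(batches, num_tokens=[1500, 1000, 2800], target=3072)
+# -> [{"indices": [0, 1], "longest": 1500}, {"indices": [2], "longest": 2800}]
+# Samples are appended to the current pack until the target would be exceeded,
+# then a new pack is started; no reordering or backtracking is performed.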
+
+
+class RLDataPacker:
+    def __init__(
+        self,
+        pack_max_length: int,
+        world_size: int,
+        data_replicate_size: int,
+        optimizer_steps: int,
+        pack_strategy: str = "greedy",
+        model_cfg: TransformerConfig | None = None,
+        worker_log_dir: str | None = None,
+        seed: int = 42,
+    ):
+        self.pack_max_length = pack_max_length
+        self.world_size = world_size
+        self.data_replicate_size = data_replicate_size
+        self.optimizer_steps = optimizer_steps
+        self.split_size = 1024
+        if worker_log_dir is not None:
+            self.worker_log_dir = Path(worker_log_dir) if isinstance(worker_log_dir, str) else worker_log_dir
+            self.logger = get_logger(log_dir=self.worker_log_dir, tag="TrainingController")
+        else:
+            self.logger = get_logger()
+
+        self.data_batch_properties = {
+            "is_qwen3_vl": False,
+            "has_rollout_routed_experts": False,
+            "has_rollout_logprobs": False,
+            "n_routed_experts": None,
+        }
+        self.strategy_map = {
+            "greedy": self.greedy_pack_and_split,
+            "balance": self.balance_split_and_pack,
+            "native": self.native_split_and_pack,
+        }
+        if pack_strategy not in self.strategy_map:
+            raise ValueError(f"Unknown packing strategy: {pack_strategy}")
+        self._impl = self.strategy_map[pack_strategy]
+        self.dp_size = self.world_size // self.data_replicate_size
+        self.padding_tokens = 0
+        self.model_cfg = model_cfg
+        self.seed = seed
+
+    def pack(self, data_batches: list[WorkerInputItem]) -> tuple[list[list[list[WorkerInputItem]]], int]:
+        self.padding_tokens = 0
+        if not data_batches:
+            return [], 0
+        self._set_data_batch_properties(data_batches)
+        return self._impl(data_batches), self.padding_tokens
+
+    def native_split_and_pack(self, data_batches: list[WorkerInputItem]) -> list[list[list[WorkerInputItem]]]:
+        # 1. Preprocess so the number of samples is divisible by dp_size.
+        if len(data_batches) % self.dp_size != 0:
+            pad_num = self.dp_size - (len(data_batches) % self.dp_size)
+            padding_item = self._create_padding_item(self.split_size, self.pack_max_length)
+            data_batches.extend([padding_item] * pad_num)
+
+        # 2. Split the samples across the dp_size ranks.
+        batches_per_dp_group: list[list[WorkerInputItem]] = np.array_split(data_batches, self.dp_size)
+        actual_optimizer_steps = min(len(batches_per_dp_group[0]), self.optimizer_steps)
+        packed_data_batches: list[list[list[WorkerInputItem]]] = [
+            [[] for _ in range(actual_optimizer_steps)] for _ in range(self.dp_size)
+        ]
+        max_packs_per_step = [0] * actual_optimizer_steps
+
+        for dp_rank, dp_worker_data_batches in enumerate(batches_per_dp_group):
+            # 3. Split each rank's samples across actual_optimizer_steps.
+            batches_for_optim_steps = np.array_split(dp_worker_data_batches, actual_optimizer_steps)
+            for step_idx, step_mini_batches in enumerate(batches_for_optim_steps):
+                # 4. Pack the samples of each optimizer step.
+                each_step_pack_list = self._pack(step_mini_batches, self.pack_max_length)
+                packed_data_batches[dp_rank][step_idx] = each_step_pack_list
+                max_packs_per_step[step_idx] = max(
+                    max_packs_per_step[step_idx], len(packed_data_batches[dp_rank][step_idx])
+                )
+
+        self.logger.info(f"Gradient accumulation steps per optimizer step: {max_packs_per_step}")
+
+        # 5. Pad so that every worker has the same number of packs in each optimizer step.
+        for step_idx in range(actual_optimizer_steps):
+            max_packs = max_packs_per_step[step_idx]
+            for dp_rank in range(self.dp_size):
+                num_current_packs = len(packed_data_batches[dp_rank][step_idx])
+                num_padding_packs = max_packs - num_current_packs
+
+                if num_padding_packs > 0:
+                    padding_items = [
+                        self._create_padding_item(self.pack_max_length, self.pack_max_length)
+                        for _ in range(num_padding_packs)
+                    ]
+                    packed_data_batches[dp_rank][step_idx].extend(padding_items)
+        return packed_data_batches
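+
+    # Note on the np.array_split calls above: splitting N items into k groups gives
+    # the first N % k groups one extra item each, e.g. np.array_split(list(range(5)), 2)
+    # -> [array([0, 1, 2]), array([3, 4])]; rank 0's group is therefore never smaller
+    # than the others, which is why actual_optimizer_steps is derived from
+    # batches_per_dp_group[0].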
+
+    def balance_split_and_pack(self, data_batches: list[WorkerInputItem]) -> list[list[list[WorkerInputItem]]]:
+        # 1. Keep the total sample length received by each rank roughly equal.
+        max_lengths = self._get_seqlen_from_data_batches(data_batches)
+
+        torch_generator = torch.Generator().manual_seed(self.seed)
+        random_generator = random.Random(self.seed)
+
+        indices = get_length_grouped_indices(
+            max_lengths=max_lengths,
+            group_batch_size=len(data_batches),
+            group_size=self.dp_size,
+            torch_generator=torch_generator,
+            random_generator=random_generator,
+        )
+
+        partitioned_data: list[list[list[WorkerInputItem]]] = [
+            [[] for _ in range(self.optimizer_steps)] for _ in range(self.dp_size)
+        ]
+
+        # 2. Assign samples to each rank's optimizer steps according to the grouped indices.
+        for i, idx in enumerate(indices):
+            dp_rank = i % self.dp_size
+            step_idx = (i // self.dp_size) % self.optimizer_steps
+            partitioned_data[dp_rank][step_idx].append(data_batches[idx])
+
+        packed_data_batches: list[list[list[WorkerInputItem]]] = [
+            [[] for _ in range(self.optimizer_steps)] for _ in range(self.dp_size)
+        ]
+
+        max_packs_per_step = [0] * self.optimizer_steps
+
+        for dp_rank in range(self.dp_size):
+            for step_idx in range(self.optimizer_steps):
+                # 3. Pack the samples of each optimizer step on each rank.
+                step_data = partitioned_data[dp_rank][step_idx]
+                packed_step_data = self._pack(step_data, self.pack_max_length)
+                packed_data_batches[dp_rank][step_idx] = packed_step_data
+                max_packs_per_step[step_idx] = max(
+                    max_packs_per_step[step_idx], len(packed_data_batches[dp_rank][step_idx])
+                )
+
+        # 4. Pad so that every worker has the same number of packs in each optimizer step.
+        for step_idx in range(self.optimizer_steps):
+            max_packs = max_packs_per_step[step_idx]
+            for dp_rank in range(self.dp_size):
+                num_current_packs = len(packed_data_batches[dp_rank][step_idx])
+                num_padding_packs = max_packs - num_current_packs
+
+                if num_padding_packs > 0:
+                    padding_items = [
+                        self._create_padding_item(self.pack_max_length, self.pack_max_length)
+                        for _ in range(num_padding_packs)
+                    ]
+                    packed_data_batches[dp_rank][step_idx].extend(padding_items)
+        return packed_data_batches
+
+    def greedy_pack_and_split(self, data_batches: list[WorkerInputItem]) -> list[list[list[WorkerInputItem]]]:
+        # 1. Greedily pack all samples into a flat, one-dimensional list of packs.
+        total_data_batches = self._pack(data_batches, self.pack_max_length)
+        # 2. For an even split, pad the whole batch so the total pack count is divisible by dp_size.
+        dp_size = self.world_size // self.data_replicate_size
+        num_packed_data_batches = len(total_data_batches)
+        pad_num = math.ceil(num_packed_data_batches / dp_size) * dp_size - num_packed_data_batches
+        if pad_num > 0:
+            pad_data_samples = [
+                self._create_padding_item(self.pack_max_length, self.pack_max_length) for _ in range(pad_num)
+            ]
+            total_data_batches = total_data_batches + pad_data_samples
+
+        # 3. Redistribute the padded pack list across dp_size and optimizer_steps.
+        each_dp_batches_num = len(total_data_batches) // dp_size
+        if each_dp_batches_num < self.optimizer_steps:
+            iters_per_step = 1  # each optimizer step has at least one batch
+            actual_optimizer_steps = each_dp_batches_num
+        else:
+            iters_per_step = math.ceil(each_dp_batches_num / self.optimizer_steps)
+            actual_optimizer_steps = math.ceil(each_dp_batches_num / iters_per_step)
+        packed_data_batches: list[list[list[WorkerInputItem]]] = [
+            [[] for _ in range(actual_optimizer_steps)] for _ in range(dp_size)
+        ]
+        for dp_rank in range(dp_size):
+            for step in range(actual_optimizer_steps):
+                start_idx = dp_rank * each_dp_batches_num + step * iters_per_step
+                end_idx = min(start_idx + iters_per_step, each_dp_batches_num * (dp_rank + 1))
+                packed_data_batches[dp_rank][step] = total_data_batches[start_idx:end_idx]
+        return packed_data_batches
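+
+    # A worked example of the redistribution above (illustrative numbers): with
+    # 8 packs, dp_size=2 and optimizer_steps=2, each_dp_batches_num=4 and
+    # iters_per_step=2, so rank0 receives packs [0, 1] and [2, 3] while rank1
+    # receives [4, 5] and [6, 7]; when each_dp_batches_num < optimizer_steps the
+    # schedule collapses to one pack per optimizer step.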
+
+    def _get_seqlen_from_data_batches(self, data_batches: list[WorkerInputItem]) -> list[int]:
+        seqlen_list = []
+        for data in data_batches:
+            assert data["seq_ctx"].input_ids.numel() <= self.pack_max_length, (  # type: ignore[union-attr]
+                f"Single sample seq len {data['seq_ctx'].input_ids.numel()} exceeds pack_max_length {self.pack_max_length}"  # type: ignore[union-attr]
+            )
+            seqlen_list.append(data["seq_ctx"].input_ids.numel())  # type: ignore[union-attr]
+        return seqlen_list
+
+    def _set_data_batch_properties(self, data_batches: list[WorkerInputItem]):
+        if not data_batches:
+            return
+
+        first_item = data_batches[0]
+        seq_ctx = first_item["seq_ctx"]
+
+        self.data_batch_properties["is_qwen3_vl"] = (
+            seq_ctx.position_ids is not None and len(seq_ctx.position_ids.shape) == 3
+        )
+        self.data_batch_properties["has_rollout_logprobs"] = (
+            "rollout_logprobs" in first_item and first_item["rollout_logprobs"] is not None
+        )
+        self.data_batch_properties["has_rollout_routed_experts"] = seq_ctx.rollout_routed_experts is not None
+
+        language_cfg = None
+        if self.data_batch_properties["has_rollout_routed_experts"]:
+            language_cfg = self.model_cfg
+            if isinstance(self.model_cfg, BaseComposeConfig):
+                language_cfg = self.model_cfg.text_config
+
+        self.data_batch_properties["n_routed_experts"] = (
+            language_cfg.n_routed_experts if language_cfg is not None else None
+        )
+        self.logger.info(f"Data batch properties set: {self.data_batch_properties}")
+
+    def _pack(self, data_batches: list[WorkerInputItem], pack_max_length: int) -> list[WorkerInputItem]:
+        seqlen_list = self._get_seqlen_from_data_batches(data_batches)
+        total_length = sum(seqlen_list)
+        each_step_pack_list: list[WorkerInputItem] = []
+        if total_length > pack_max_length:
+            # TODO: add expand_soft pack strategy
+            pack_infos = get_soft_pack_infos(
+                data_batches,
+                seqlen_list,
+                pack_max_length,
+            )
+            for pack_info in pack_infos:
+                indices = pack_info["indices"]
+                batch4pack = [data_batches[i] for i in indices]
+                each_step_pack_list.append(self._single_pack(batch4pack, pack_max_length))
+        else:
+            each_step_pack_list.append(self._single_pack(data_batches, pack_max_length))
+        return each_step_pack_list
+
+    def _single_pack(self, data_batches: list[WorkerInputItem], pack_max_length: int) -> WorkerInputItem:
+        seq_ctx_list = [item["seq_ctx"] for item in data_batches]
+        label_list = [item["shifted_labels"] for item in data_batches]
+        advantage_list = []
+        for item in data_batches:
+            advantages = item["advantages"].reshape(1, -1)
+            advantage_list.append(advantages)
+
+        rollout_logprobs_list = [
+            item["rollout_logprobs"] if self.data_batch_properties["has_rollout_logprobs"] else None
+            for item in data_batches
+        ]
+        seqlen_list = self._get_seqlen_from_data_batches(data_batches)
+        cur_length = sum(seqlen_list)
+        padding_len = pack_max_length - cur_length
+
+        padding_item: WorkerInputItem | None = None
+        if padding_len > 0:
+            padding_item = self._create_padding_item(padding_len, pack_max_length)
+            seq_ctx_list.append(padding_item["seq_ctx"])
+            label_list.append(padding_item["shifted_labels"])
+            advantage_list.append(padding_item["advantages"])
+            rollout_logprobs_list.append(padding_item["rollout_logprobs"])
+
+        packed_seq_ctx = SequenceContext.cat(seq_ctx_list)
+        packed_shifted_labels = torch.cat(label_list, dim=1)  # type: ignore[arg-type]
+        packed_shifted_labels = cast(torch.LongTensor, packed_shifted_labels)
+        packed_advantages = torch.cat(advantage_list, dim=1)
+        if self.data_batch_properties["has_rollout_logprobs"]:
+            cast_rollout_logprobs_list = [cast(torch.Tensor, item) for item in rollout_logprobs_list]
+            packed_rollout_logprobs = torch.cat(cast_rollout_logprobs_list, dim=1)
+        else:
+            packed_rollout_logprobs = None
+
+        optimizer_step_packs: WorkerInputItem = {
+            "seq_ctx": packed_seq_ctx,
+            "shifted_labels": packed_shifted_labels,
+            "advantages": packed_advantages,
+            "rollout_logprobs": packed_rollout_logprobs,
+        }
+        packed_input_ids = cast(torch.Tensor, packed_seq_ctx.input_ids)
+        assert packed_input_ids.numel() == pack_max_length, (
+            f"Packed seq ctx length {packed_input_ids.numel()} does not match pack_max_length {pack_max_length}. "
+            f"padding input_ids length: {padding_item['seq_ctx'].input_ids.shape if padding_item else 0}"  # type: ignore[union-attr]
+        )
+        # Padding positions carry -100 advantages, so the padding count derived from
+        # advantages must match the SequenceContext bookkeeping.
+        assert packed_seq_ctx.num_padding == (packed_advantages == -100).sum().item(), (
+            f"Packed seq ctx num_padding {packed_seq_ctx.num_padding} and packed advantages num_padding "
+            f"{(packed_advantages == -100).sum().item()} mismatch after packing."
+        )
+        return optimizer_step_packs
+
+    def _create_padding_item(
+        self,
+        pad_len: int,
+        pack_max_length: int,
+    ) -> WorkerInputItem:
+        # padding input_ids
+        self.padding_tokens += pad_len
+        pad_tokens = tuple(
+            torch.zeros(1, self.split_size, dtype=torch.long, device="cpu") for _ in range(pad_len // self.split_size)
+        )
+        if pad_len % self.split_size > 0:
+            pad_tokens = pad_tokens + (torch.zeros(1, pad_len % self.split_size, dtype=torch.long, device="cpu"),)
+        pad_tokens = cast(tuple[torch.LongTensor, ...], pad_tokens)
+        pad_seq_ctx = SequenceContext.from_input_ids(pad_tokens, device="cpu")
+        pad_seq_ctx.num_padding = pad_len
+
+        # padding mm positions_ids
+        if self.data_batch_properties["is_qwen3_vl"]:
+            _position_ids_list = []
+            for pad_token in pad_tokens:
+                _position_ids = torch.arange(pad_token.size(-1)).view(1, 1, -1).expand(3, 1, -1)
+                _position_ids_list.append(_position_ids)
+            position_ids = torch.cat(_position_ids_list, dim=-1)
+            position_ids = cast(torch.LongTensor, position_ids)
+            pad_seq_ctx.position_ids = position_ids
+
+        # padding rollout routed experts
+        if self.data_batch_properties["has_rollout_routed_experts"]:
+            assert self.data_batch_properties["n_routed_experts"], (
+                "n_routed_experts must be provided when has_rollout_routed_experts is True"
+            )
+            if pad_len == pack_max_length:
+                pad_rand_index = torch.randint(
+                    low=0, high=1, size=(1, 1, 1)
+                )  # add dummy data, true data will be initialized in train worker.fit
+            else:
+                pad_rand_index = torch.randint(
+                    low=0, high=self.data_batch_properties["n_routed_experts"], size=(pad_len, 1, 1)
+                )
+            pad_seq_ctx.rollout_routed_experts = pad_rand_index
+
+        pad_labels = cast(torch.LongTensor, torch.full((1, pad_len), -100, dtype=torch.int64, device="cpu"))
+
+        pad_advantage = torch.full((1, pad_len), -100, dtype=torch.float32, device="cpu")
+        pad_rollout_logprobs = (
+            torch.zeros(1, pad_len, dtype=torch.float32, device="cpu")
+            if self.data_batch_properties["has_rollout_logprobs"]
+            else None
+        )
+
+        padding_item: WorkerInputItem = {
+            "seq_ctx": pad_seq_ctx,
+            "shifted_labels": pad_labels,
+            "advantages": pad_advantage,
+            "rollout_logprobs": pad_rollout_logprobs,
+        }
+        return padding_item
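+
+
+# Why padding is chunked: _create_padding_item emits padding as several
+# split_size (1024-token) pieces rather than one long dummy sequence, so
+# attention over padded regions is computed as several short sequences; this
+# keeps the trick from the replaced controller code ("Reduce the attn
+# calculation time by using multiple short sequence packs").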
diff --git a/xtuner/v1/rl/base/worker.py b/xtuner/v1/rl/base/worker.py
index 86012ca40..ca70f1501 100644
--- a/xtuner/v1/rl/base/worker.py
+++ b/xtuner/v1/rl/base/worker.py
@@ -1,10 +1,9 @@
 import json
-import math
 import os
 import time
 from itertools import chain
 from pathlib import Path
-from typing import Dict, Iterable, List, TypeAlias, TypedDict, cast
+from typing import Dict, Iterable, List, Literal, TypeAlias, TypedDict, cast
 
 import ray
 import requests
@@ -143,7 +142,7 @@ class WorkerConfig(BaseModel):
     log_dir: str | Path | None = None
     update_weight_bucket_size_in_gb: float = 0.5  # 512MB
     seed: None | int = None  # if None, use RLTrainer seed
-
+    pack_strategy: Literal["greedy", "balance", "native"] = "greedy"
     # sft config
     sft_dataloader_cfg: DataloaderConfig | None = None
     sft_global_batch_size: int = -1
@@ -416,113 +415,121 @@ def _get_rl_other_log(self, other_log: OtherLog) -> RLOtherLog:
         return rl_other_log
 
     @ray_method
-    def fit(self, data_batches: list[WorkerInputItem], rollout_idx: int) -> WorkerLogItem:
+    def fit(self, data_batches: list[list[WorkerInputItem]], rollout_idx: int) -> WorkerLogItem:
         # NOTE: sglang会清除logger handle, 重新创建
         self.logger = get_logger(log_dir=self.log_dir, tag="TrainingWorker")
         loss_cfg = self.config.loss_cfg
 
         num_batches = len(data_batches)
-        iters_per_step = math.ceil(num_batches / self._optimizer_steps)
-        if num_batches < self._optimizer_steps:
-            self.logger.info(
-                f"Optimizer only step once because num_batches {num_batches} < optimizer_steps {self._optimizer_steps}."
+        actual_optimizer_steps = self.config.optimizer_steps
+        if num_batches < actual_optimizer_steps:
+            actual_optimizer_steps = num_batches
+            self.logger.warning(
+                f"data_batches num {num_batches} is less than optimizer_steps {self.config.optimizer_steps}, "
+                f"set optimizer_steps to {num_batches}"
             )
+        assert num_batches == actual_optimizer_steps, (
+            f"data_batches num {num_batches} must be equal to optimizer_steps {actual_optimizer_steps}"
+        )
+        packed_batch_num_per_step = []
 
         seq_ctx_list: list[SequenceContext] = []
         loss_ctx_input_list: list[RLLossContextInputItem] = []
         rollout_logprobs_list: list[torch.Tensor | None] = []
-        # convert dummy padding experts to real size
-
         language_cfg = (
             self.config.model_cfg.text_config
             if isinstance(self.config.model_cfg, BaseComposeConfig)
            else self.config.model_cfg
         )
-        for data in data_batches:
-            seq_ctx = data["seq_ctx"]
-            pixel_values = seq_ctx.pixel_values
-            if pixel_values is not None:
-                if not isinstance(pixel_values, torch.Tensor):
-                    assert isinstance(pixel_values, list), (
-                        f"pixel_values should be list of tensor, got {type(pixel_values)}"
-                    )
-                    pixel_values = [ray.get(pixel_obf) for pixel_obf in pixel_values]
-                    pixel_values = torch.cat(pixel_values, dim=0)
-                seq_ctx.pixel_values = pixel_values
-
-            rollout_routed_experts = seq_ctx.rollout_routed_experts
-            if rollout_routed_experts is not None:
-                to_free_routed_expert_refs: list[ray.ObjectRef] = []
-                if isinstance(rollout_routed_experts, list):
-                    # list[n,l,e]
-                    out_rollout_routed_expert = []
-                    for rollout_routed_expert in rollout_routed_experts:
-                        if isinstance(rollout_routed_expert, torch.Tensor):
-                            rollout_routed_experts_tensor = torch.randint(
-                                low=0,
-                                high=language_cfg.n_routed_experts,
-                                size=(
-                                    rollout_routed_expert.size(0),
-                                    language_cfg.num_hidden_layers,
-                                    language_cfg.num_experts_per_tok,
-                                ),
-                            )
-                            out_rollout_routed_expert.append(rollout_routed_experts_tensor)
-                        else:
-                            rollout_routed_expert_refs = rollout_routed_expert
-                            rollout_routed_expert = ray.get(rollout_routed_expert_refs)
-                            # free obj store explicitly
-                            if self.sp_mesh is None or self.sp_mesh.size() == 1:
-                                ray._private.internal_api.free(rollout_routed_expert_refs)
+        for step_data_batches in data_batches:
+            # the number of packed batches per step determines the gradient accumulation steps
+            packed_batch_num_per_step.append(len(step_data_batches))
+            for data in step_data_batches:
+                seq_ctx = data["seq_ctx"]
+                pixel_values = seq_ctx.pixel_values
+                if pixel_values is not None:
+                    if not isinstance(pixel_values, torch.Tensor):
+                        assert isinstance(pixel_values, list), (
+                            f"pixel_values should be list of tensor, got {type(pixel_values)}"
+                        )
+                        pixel_values = [ray.get(pixel_obf) for pixel_obf in pixel_values]
+                        pixel_values = torch.cat(pixel_values, dim=0)
+                    seq_ctx.pixel_values = pixel_values
+
+                rollout_routed_experts = seq_ctx.rollout_routed_experts
+                if rollout_routed_experts is not None:
+                    to_free_routed_expert_refs: list[ray.ObjectRef] = []
+                    if isinstance(rollout_routed_experts, list):
+                        # list[n,l,e]
+                        out_rollout_routed_expert = []
+                        for rollout_routed_expert in rollout_routed_experts:
+                            if isinstance(rollout_routed_expert, torch.Tensor):
+                                rollout_routed_experts_tensor = torch.randint(
+                                    low=0,
+                                    high=language_cfg.n_routed_experts,
+                                    size=(
+                                        rollout_routed_expert.size(0),
+                                        language_cfg.num_hidden_layers,
+                                        language_cfg.num_experts_per_tok,
+                                    ),
+                                )
+                                out_rollout_routed_expert.append(rollout_routed_experts_tensor)
                             else:
-                            if self.sp_mesh.get_local_rank() == 0:
-                                # only free once of sp mesh
-                                to_free_routed_expert_refs.append(rollout_routed_expert_refs)
-                        out_rollout_routed_expert.append(torch.as_tensor(rollout_routed_expert, dtype=torch.long))
-
-                    seq_ctx.rollout_routed_experts = torch.cat(out_rollout_routed_expert, dim=0)  # max_len,l,e
-                else:
-                    assert isinstance(rollout_routed_experts, torch.Tensor), (
-                        f"padding experts should be a dummy tensor, bug got {type(rollout_routed_experts)}"
-                    )
-                    rollout_routed_experts_tensor = torch.randint(
-                        low=0,
-                        high=language_cfg.n_routed_experts,
-                        size=(
-                            self.config.pack_max_length,
-                            language_cfg.num_hidden_layers,
-                            language_cfg.num_experts_per_tok,
-                        ),
-                    )
-                    seq_ctx.rollout_routed_experts = rollout_routed_experts_tensor
-
-                assert seq_ctx.input_ids is not None, "input_ids is None"
-                assert seq_ctx.rollout_routed_experts.size(0) == seq_ctx.input_ids.size(1)
-
-                if self.sp_mesh is not None and self.sp_mesh.size() > 1:
-                    dist.barrier()
-                    for free_routed_expert_refs in to_free_routed_expert_refs:
-                        ray._private.internal_api.free(free_routed_expert_refs)
-                    del to_free_routed_expert_refs
-
-            seq_ctx = data["seq_ctx"].to(DEVICE)
-            rollout_logprobs = data.get("rollout_logprobs", None)
-            if rollout_logprobs is not None:
-                rollout_logprobs = rollout_logprobs.to(DEVICE)
-            rollout_logprobs_list.append(rollout_logprobs)
-            loss_ctx_input = RLLossContextInputItem(
-                shifted_labels=data["shifted_labels"],
-                advantages=data["advantages"],
-                rollout_logprobs=rollout_logprobs,
-            ).to(DEVICE)
-            if self.sp_mesh.size() > 1:
-                seq_ctx = seq_ctx.split(self.sp_mesh)
-                loss_ctx_input = loss_ctx_input.sp_split(self.sp_mesh)
-            seq_ctx_list.append(seq_ctx)
-            loss_ctx_input_list.append(loss_ctx_input)
+                                rollout_routed_expert_refs = rollout_routed_expert
+                                rollout_routed_expert = ray.get(rollout_routed_expert_refs)
+                                # free obj store explicitly
+                                if self.sp_mesh is None or self.sp_mesh.size() == 1:
+                                    ray._private.internal_api.free(rollout_routed_expert_refs)
+                                else:
+                                    if self.sp_mesh.get_local_rank() == 0:
+                                        # only free once of sp mesh
+                                        to_free_routed_expert_refs.append(rollout_routed_expert_refs)
+                                out_rollout_routed_expert.append(
+                                    torch.as_tensor(rollout_routed_expert, dtype=torch.long)
+                                )
+
+                        seq_ctx.rollout_routed_experts = torch.cat(out_rollout_routed_expert, dim=0)  # max_len,l,e
+                    else:
+                        assert isinstance(rollout_routed_experts, torch.Tensor), (
+                            f"padding experts should be a dummy tensor, but got {type(rollout_routed_experts)}"
+                        )
+                        rollout_routed_experts_tensor = torch.randint(
+                            low=0,
+                            high=language_cfg.n_routed_experts,
+                            size=(
+                                self.config.pack_max_length,
+                                language_cfg.num_hidden_layers,
+                                language_cfg.num_experts_per_tok,
+                            ),
+                        )
+                        seq_ctx.rollout_routed_experts = rollout_routed_experts_tensor
+
+                    assert seq_ctx.input_ids is not None, "input_ids is None"
+                    assert seq_ctx.rollout_routed_experts.size(0) == seq_ctx.input_ids.size(1)
+
+                    if self.sp_mesh is not None and self.sp_mesh.size() > 1:
+                        dist.barrier()
+                        for free_routed_expert_refs in to_free_routed_expert_refs:
+                            ray._private.internal_api.free(free_routed_expert_refs)
+                        del to_free_routed_expert_refs
+
+                seq_ctx = data["seq_ctx"].to(DEVICE)
+                rollout_logprobs = data.get("rollout_logprobs", None)
+                if rollout_logprobs is not None:
+                    rollout_logprobs = rollout_logprobs.to(DEVICE)
+                rollout_logprobs_list.append(rollout_logprobs)
+                loss_ctx_input = RLLossContextInputItem(
+                    shifted_labels=data["shifted_labels"],
+                    advantages=data["advantages"],
+                    rollout_logprobs=rollout_logprobs,
+                ).to(DEVICE)
+                if self.sp_mesh.size() > 1:
+                    seq_ctx = seq_ctx.split(self.sp_mesh)
+                    loss_ctx_input = loss_ctx_input.sp_split(self.sp_mesh)
+                seq_ctx_list.append(seq_ctx)
+                loss_ctx_input_list.append(loss_ctx_input)
 
         del data_batches
-
         rank_grad_tokens: torch.Tensor | None = None
         for loss_ctx_input in loss_ctx_input_list:
             mask = loss_ctx_input.shifted_labels != -100
@@ -640,9 +647,14 @@ def fit(self, data_batches: list[WorkerInputItem], rollout_idx: int) -> WorkerLo
         avg_kl_div = kl_div_sum / global_grad_tokens if global_grad_tokens > 0 else 0
         self.logger.info(f"Rollout {rollout_idx}: avg KL divergence: {avg_kl_div:.4f}")
 
-        for i in range(0, len(seq_ctx_list), iters_per_step):
-            batches_seq_ctx = seq_ctx_list[i : i + iters_per_step]
-            batches_loss_ctx_input = loss_ctx_input_list[i : i + iters_per_step]
+        start_idx = 0
+        for i in range(actual_optimizer_steps):
+            num_packs_this_step = packed_batch_num_per_step[i]
+            end_idx = start_idx + num_packs_this_step
+            batches_seq_ctx = seq_ctx_list[start_idx:end_idx]
+            batches_loss_ctx_input = loss_ctx_input_list[start_idx:end_idx]
+            start_idx = end_idx
+
             LossContext = loss_cfg.loss_ctx_cls
             batches_loss_kwargs = LossContext.build_batches_loss_kwargs(batches_loss_ctx_input, loss_cfg)
             engine_input = []
@@ -820,11 +832,20 @@ def get_data_replicate_size(self) -> int:
         # sp will affect the data replicate size in worker
         return self._engine.data_replicate_size * self.sp_mesh.size()
 
+    @ray_method
+    def get_dp_rank(self) -> int:
+        """Get the data parallel rank for this worker."""
+        return self.data_mesh["dp"].get_rank()
+
     @ray_method
     def get_model_cfg(self):
         model_cfg = self._engine.model_cfg
         return model_cfg
 
+    @ray_method
+    def get_worker_cfg(self):
+        return self.config
+
     @ray_method
     def offload_model(self):
         self._engine.put_model_to_device("cpu")
diff --git a/xtuner/v1/train/rl_trainer.py b/xtuner/v1/train/rl_trainer.py
index 693fc3f08..5942a2162 100644
--- a/xtuner/v1/train/rl_trainer.py
+++ b/xtuner/v1/train/rl_trainer.py
@@ -1,6 +1,5 @@
 import json
 import os
-import random
 from datetime import datetime
 from pathlib import Path
 from shutil import rmtree
@@ -29,10 +28,11 @@
 from xtuner.v1.rl.base import (
     TrainingController,
     TrainingControllerProxy,
+    TrainingLogInfo,
     TrainingWorkerClass,
     TrainingWorkerProxy,
     WorkerConfig,
-    WorkerLogItem,
+    WorkerInputItem,
 )
 from xtuner.v1.rl.base import TrainingWorker as BaseTrainingWorker
 from xtuner.v1.train import ResumeConfig
@@ -560,13 +560,25 @@ def _train_step(self, rollout_idx: int, data_groups, multimodal_train_infos, ste
         )
 
         with timer("training", step_timer_dict):
-            workers_log_item: List[WorkerLogItem] = ray.get(
+            training_log_info: TrainingLogInfo = ray.get(
                 self._train_controller.fit.remote(
                     data_batches, pack_max_length=self._train_worker_cfg.pack_max_length, rollout_idx=rollout_idx
                 )
             )
-        self._writer.add_scalar(tag="time/training", scalar_value=step_timer_dict["training"], global_step=rollout_idx)
+        workers_log_item = training_log_info["worker_log_infos"]
+        self._writer.add_scalar(tag="time/training", scalar_value=step_timer_dict["training"], global_step=rollout_idx)
+        self._writer.add_scalar(
+            tag="time/pack_time", scalar_value=training_log_info["pack_time"], global_step=rollout_idx
+        )
+        self._writer.add_scalar(
+            tag="time/train_time", scalar_value=training_log_info["train_time"], global_step=rollout_idx
+        )
+        self._writer.add_scalar(
+            tag="train_metrics/padding_tokens",
+            scalar_value=training_log_info["padding_tokens"],
+            global_step=rollout_idx,
+        )
         rank0_log_item = workers_log_item[0]
         # These metrics are already aggregated across distributed workers and logging only the metrics from rank 0.
         rank0_rollout_is_metrics = rank0_log_item.get("rollout_is_metrics")
@@ -751,7 +763,8 @@ def _prepare_train_data(self, data_groups, pack_max_length, multimodal_train_inf
             prompt_len_list.append(len(prompt_ids))
             response_len_list.append(len(response_ids))
 
-            advantages_list.extend([advantages[i]] * len(response_ids))
+            token_advantages = torch.repeat_interleave(advantages[i], len(input_ids))
+            advantages_list.extend(token_advantages)
 
             shifted_labels = [-100] * (len(prompt_ids) - 1) + response_ids
             assert len(input_ids) <= pack_max_length, f"{len(input_ids)} vs {pack_max_length}"
@@ -767,10 +780,10 @@ def _prepare_train_data(self, data_groups, pack_max_length, multimodal_train_inf
                 rollout_logprobs = None
 
             seq_ctx = get_train_seq_ctx(input_ids, multimodal_train_info, len(response_ids) - 1)
-            data_dict = {
+            data_dict: WorkerInputItem = {
                 "seq_ctx": seq_ctx,
                 "shifted_labels": shifted_labels,
-                "advantage": advantages[i].item(),
+                "advantages": token_advantages,
                 "rollout_logprobs": rollout_logprobs,
             }
 
@@ -779,7 +792,6 @@ def _prepare_train_data(self, data_groups, pack_max_length, multimodal_train_inf
                 seq_ctx.rollout_routed_experts = routed_experts  # n,layer,expert
 
             data_batches.append(data_dict)
-        random.shuffle(data_batches)
 
         rewards_t = torch.tensor(rewards_list).float() if rewards_list else torch.tensor([0.0]).float()
         advantages_t = torch.tensor(advantages_list).float() if advantages_list else torch.tensor([0.0]).float()