197 changes: 197 additions & 0 deletions examples/v1/config/rl_colocate_trainer.py
@@ -0,0 +1,197 @@
"""RL Colocate Trainer 示例配置(GRPO + GSM8K)。

用法:通过环境变量传入路径后,由 CLI 加载本配置并 trainer_cfg.build().fit()。
需设置: WORK_DIR, MODEL_PATH, DATA_PATH, EVAL_DATA_PATH
可选: WORLD_SIZE, ENABLE_RETURN_ROUTED_EXPERTS, LOSS_TYPE, LOSS_MODE, SP_SIZE
"""
import os
from pathlib import Path

from xtuner.v1.config import AdamWConfig, FSDPConfig, LRConfig
from xtuner.v1.data_proto import SampleParams
from xtuner.v1.datasets.config import DataloaderConfig, DatasetConfig
from xtuner.v1.datasets.rl_tokenize_fn import RLTextTokenizeFnConfig
from xtuner.v1.model import get_model_config_from_hf
from xtuner.v1.ray.base import AcceleratorResourcesConfig
from xtuner.v1.ray.config.worker import RolloutConfig
from xtuner.v1.ray.judger.gsm8k import GSM8KJudgerConfig
from xtuner.v1.rl.base import WorkerConfig
from xtuner.v1.rl.base.agent_loop import SingleTurnAgentLoopConfig
from xtuner.v1.rl.base.agent_loop_manager import AgentLoopManagerConfig
from xtuner.v1.rl.base.producer import SyncProduceStrategyConfig
from xtuner.v1.rl.base.sampler import SamplerConfig
from xtuner.v1.rl.evaluator import EvaluatorConfig
from xtuner.v1.rl.grpo import GRPOLossConfig
from xtuner.v1.train.rl_colocate_trainer import RLColocateTrainerConfig

# env
work_dir = os.environ["WORK_DIR"]
model_path = os.environ["MODEL_PATH"]
data_path = os.environ["DATA_PATH"]
eval_data_path = os.environ["EVAL_DATA_PATH"]
enable_return_routed_experts = os.environ.get("ENABLE_RETURN_ROUTED_EXPERTS", "0")
WORLD_SIZE = int(os.environ.get("WORLD_SIZE", "1"))

# basic settings
experimental_name = "grpo_gsm8k"
rollout_steps = 45
evaluate_step = 45
train_optimizer_steps = 1
global_batch_size = 64 * train_optimizer_steps
prompt_repeat_k = 5
rollout_tp_size = 1
rollout_ep_size = 1
max_prompt_length = 512
max_response_length = 1024
pack_max_length = 32 * 1024
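
# Derived sizes: with train_optimizer_steps = 1 the global batch is 64 prompts
# per rollout step; each prompt is repeated prompt_repeat_k = 5 times (the GRPO
# group size, assuming one group per prompt), giving 64 * 5 = 320 rollouts per
# step. pack_max_length = 32 * 1024 = 32768 tokens bounds each packed sequence.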

# 1. resources
resources = AcceleratorResourcesConfig(
    accelerator="GPU",
    num_workers=8 * WORLD_SIZE,
    num_cpus_per_worker=12,
    cpu_memory_per_worker=16 * 1024**3,  # 16 GB
)

# 2. rollout
rollout_config = RolloutConfig(
    env=experimental_name,
    device=resources.accelerator,
    model_path=model_path,
    dtype="bfloat16",
    tensor_parallel_size=rollout_tp_size,
    expert_parallel_size=rollout_ep_size,
    gpu_memory_utilization=0.8,
    context_length=max_response_length + max_prompt_length,
    enable_return_routed_experts=(enable_return_routed_experts == "1"),
)
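# context_length = 1024 + 512 = 1536 tokens covers prompt plus response;
# gpu_memory_utilization = 0.8 presumably leaves headroom for the colocated
# training worker that shares the same GPUs.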

# 3. judger
judger_config = GSM8KJudgerConfig(judger_name="openai/gsm8k")
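# The GSM8K judger typically parses the final number from each response and
# returns a binary reward for an exact match with the reference answer.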

# 4. train worker
lr_cfg = LRConfig(lr_type="constant", warmup_ratio=0, lr_min=1e-6)
fsdp_cfg = FSDPConfig(torch_compile=False, cpu_offload=False, ep_size=1)
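# MoE auxiliary losses (load balancing / router z-loss) are disabled below,
# presumably so RL optimizes the GRPO objective alone; the hasattr guards keep
# this config valid for dense models that lack these fields.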
model_cfg = get_model_config_from_hf(Path(model_path))
if hasattr(model_cfg, "balancing_loss_cfg"):
model_cfg.balancing_loss_cfg = None
if hasattr(model_cfg, "z_loss_cfg"):
model_cfg.z_loss_cfg = None
optim_cfg = AdamWConfig(lr=1e-6, foreach=False, weight_decay=0.1)
loss_cfg = GRPOLossConfig(
    policy_loss_cfg=dict(
        cliprange_high=0.28,
        cliprange_low=0.2,
        loss_type=os.environ.get("LOSS_TYPE", "vanilla"),
        clip_ratio_c=10.0,
        log_prob_diff_min=-20.0,
        log_prob_diff_max=20.0,
    ),
    ignore_idx=-100,
    use_kl_loss=False,
    kl_loss_coef=0.0,
    kl_loss_type="low_var_kl",
    mode=os.environ.get("LOSS_MODE", "chunk"),
    chunk_size=512,
)
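# A sketch of the per-token clipped objective these settings parameterize,
# with advantage A and importance ratio r = pi_theta / pi_old:
#     loss = -min(r * A, clip(r, 1 - cliprange_low, 1 + cliprange_high) * A)
# The asymmetric range (low 0.2, high 0.28) follows the "clip-higher" idea of
# giving positive-advantage tokens more room before clipping.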
train_worker_cfg = WorkerConfig(
    model_cfg=model_cfg,
    load_from=model_path,
    optim_cfg=optim_cfg,
    loss_cfg=loss_cfg,
    lr_cfg=lr_cfg,
    fsdp_cfg=fsdp_cfg,
    sp_size=int(os.environ.get("SP_SIZE", "1")),
    optimizer_steps=train_optimizer_steps,
    pack_max_length=pack_max_length,
)
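# optimizer_steps = 1 means a single optimizer update per rollout step, so
# each batch of rollouts is consumed once and updates stay close to on-policy.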

# 5. train agent loop manager
train_dataset = DatasetConfig(name=experimental_name, anno_path=data_path)
tokenizer_config = RLTextTokenizeFnConfig(max_length=max_prompt_length)
train_dataset_cfg = [{"dataset": train_dataset, "tokenize_fn": tokenizer_config}]
dataloader_cfg = DataloaderConfig(
    dataset_config_list=train_dataset_cfg,
    pack_max_length=pack_max_length,
    collator="fake_collator",
    pack_level="none",
)
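# "fake_collator" and pack_level="none" suggest that batching and packing are
# handled downstream by the RL worker rather than in the dataloader itself.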
sampler_config = SamplerConfig(
    dataloader_cfg=dataloader_cfg,
    prompt_repeat_k=prompt_repeat_k,
)
training_sample_params = SampleParams(
    max_tokens=max_response_length,
    top_k=0,
    top_p=1.0,
    temperature=1.0,
    min_tokens=0,
)
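# Unrestricted sampling for training rollouts: temperature = 1.0 with
# top_p = 1.0 and top_k = 0 (conventionally, top-k filtering disabled) keeps
# the rollout distribution un-truncated for on-policy GRPO updates.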
agent_loop_config = SingleTurnAgentLoopConfig(
    hf_checkpoint=model_path,
    sample_params=training_sample_params,
)
produce_strategy_config = SyncProduceStrategyConfig()
agent_loop_manager_cfg = AgentLoopManagerConfig(
    task_name="train_task",
    agent_loop_config=agent_loop_config,
    produce_strategy_config=produce_strategy_config,
    sampler_config=sampler_config,
)
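# A synchronous produce strategy generates the full rollout batch before each
# training phase, matching the colocated setup where rollout and training
# alternate on the same GPUs.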

# 6. eval agent loop manager
eval_dataset = DatasetConfig(
    name=experimental_name, anno_path=eval_data_path, sample_ratio=1.0
)
eval_dataset_cfg = [{"dataset": eval_dataset, "tokenize_fn": tokenizer_config}]
eval_dataloader_cfg = DataloaderConfig(
    dataset_config_list=eval_dataset_cfg,
    pack_max_length=pack_max_length,
    collator="fake_collator",
    pack_level="none",
)
eval_sampler_config = SamplerConfig(
    dataloader_cfg=eval_dataloader_cfg,
    prompt_repeat_k=1,
)
evaluation_sample_params = SampleParams(
    max_tokens=max_response_length,
    top_k=1,
    top_p=1.0,
    temperature=0.0,
    min_tokens=0,
)
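# Greedy decoding for evaluation: temperature = 0.0 and top_k = 1 make eval
# runs deterministic, and prompt_repeat_k = 1 above scores each prompt once.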
eval_agent_loop_config = SingleTurnAgentLoopConfig(
    hf_checkpoint=model_path,
    sample_params=evaluation_sample_params,
)
eval_agent_loop_manager_cfg = AgentLoopManagerConfig(
    task_name="eval_task",
    agent_loop_config=eval_agent_loop_config,
    sampler_config=eval_sampler_config,
)

# 7. evaluator
evaluator_config = EvaluatorConfig(compute_metric_func=None)

# 8. RL Colocate Trainer config (the CLI obtains the trainer via config["trainer"].build())
trainer = RLColocateTrainerConfig(
    resources=resources,
    train_worker_cfg=train_worker_cfg,  # TODO: unify "cfg" vs "config" naming
    rollout_config=rollout_config,
    judger_config=judger_config,
    tokenizer_path=model_path,
    replay_buffer_config=dict(),
    agent_loop_manager_cfg=agent_loop_manager_cfg,
    eval_agent_loop_manager_cfg=eval_agent_loop_manager_cfg,
    evaluator_config=evaluator_config,
    load_from=model_path,
    rollout_steps=rollout_steps,
    global_batch_size=global_batch_size,
    enable_evaluate=True,
    enable_initial_evaluate=False,
    evaluate_step=evaluate_step,
    work_dir=work_dir,
    seed=123,
    debug_rollout=False,
)
4 changes: 2 additions & 2 deletions xtuner/v1/train/cli/rl.py
@@ -10,7 +10,6 @@
 from cyclopts.group import Group
 
 from xtuner.v1.rl.utils import register_cleanup
-from xtuner.v1.train.rl_trainer import RLTrainer
 from xtuner.v1.utils import Config
 from xtuner.v1.utils.track_rl_mem import monitor_actor_memory
 
@@ -56,7 +55,8 @@ def main(
     track_thread.start()
 
     trainer_cfg = Config.fromfile(config)["trainer"]
-    trainer = RLTrainer.from_config(trainer_cfg)
+    # trainer = RLTrainer.from_config(trainer_cfg)
+    trainer = trainer_cfg.build()
     trainer.fit()
 
     if dist.is_initialized():