From 70878d909f74dfe44b2a92b4608a6d25330927b7 Mon Sep 17 00:00:00 2001 From: yxstev Date: Mon, 23 Mar 2026 14:32:11 +0800 Subject: [PATCH] [refactor] simplify recipe demos - Remove legacy sync/async demos - Add single controller demo - Refactor tq_conf to satisfy new structure - Optimize meta field updates and clean up async interfaces --- recipe/simple_use_case/async_demo.py | 335 ------------------ .../simple_use_case/single_controller_demo.py | 318 +++++++++++++++++ recipe/simple_use_case/sync_demo.py | 222 ------------ 3 files changed, 318 insertions(+), 557 deletions(-) delete mode 100644 recipe/simple_use_case/async_demo.py create mode 100644 recipe/simple_use_case/single_controller_demo.py delete mode 100644 recipe/simple_use_case/sync_demo.py diff --git a/recipe/simple_use_case/async_demo.py b/recipe/simple_use_case/async_demo.py deleted file mode 100644 index 2bad38a2..00000000 --- a/recipe/simple_use_case/async_demo.py +++ /dev/null @@ -1,335 +0,0 @@ -# Copyright 2025 Huawei Technologies Co., Ltd. All Rights Reserved. -# Copyright 2025 The TransferQueue Team -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import asyncio -import logging -import math -import os -import sys -import time -from pathlib import Path - -import ray -import torch -from omegaconf import OmegaConf -from tensordict import NonTensorData, TensorDict - -parent_dir = Path(__file__).resolve().parent.parent.parent -sys.path.append(str(parent_dir)) - -from transfer_queue import ( # noqa: E402 - AsyncTransferQueueClient, - BatchMeta, - SimpleStorageUnit, - TransferQueueController, - process_zmq_server_info, -) -from transfer_queue.utils.common import get_placement_group # noqa: E402 - -logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s") -logger = logging.getLogger(__name__) - -os.environ["RAY_DEDUP_LOGS"] = "0" -os.environ["RAY_DEBUG"] = "1" -ray.init() - - -def compute_old_log_prob(data1, data2): - time.sleep(3) - return data1 - - -def generate_sequences(data): - time.sleep(3) - return data - - -class ActorRolloutRefWorker: - def actor_rollout_wg_generate_sequences(self, data_meta, data_system_client): - # 1. Pull real data from the storage plane through client based on data_meta - data = asyncio.run(data_system_client.async_get_data(data_meta)) - logger.info(f"demo get data->generate_sequences {data}") - - output = generate_sequences(data["input_ids"]) - - output = TensorDict( - { - "generate_sequences_ids": output, - "non_tensor_data": torch.stack([NonTensorData("test_str") for _ in range(output.size(0))]), - "nested_tensor": torch.nested.as_nested_tensor( - [torch.randn(1, 2) for _ in range(output.size(0))], layout=torch.jagged - ), - }, - batch_size=output.size(0), - ) - - # 2. Write results back to the storage plane based on data_meta - asyncio.run(data_system_client.async_put(data=output, metadata=data_meta)) - data_meta.add_fields(output) - logger.info("demo put data to storages done") - - return data_meta - - def actor_rollout_wg_compute_old_log_prob(self, data_meta, data_system_client): - # 1. Pull real data from the storage plane through client based on data_meta - data = asyncio.run(data_system_client.async_get_data(data_meta)) - logger.info(f"demo get data->old_log_prob {data}") - - output = compute_old_log_prob(data["input_ids"], data["generate_sequences_ids"]) - - output = TensorDict({"old_log_prob": output}, batch_size=output.size(0)) - - # 2. Write results back to the storage plane based on data_meta - asyncio.run(data_system_client.async_put(data=output, metadata=data_meta)) - data_meta.add_fields(output) - logger.info("demo put data to storages done") - - return data_meta - - -@ray.remote -class AsyncvLLMServer: - def __init__(self, config, data_system_controller_info): - self.config = config - self.data_system_client = AsyncTransferQueueClient( - client_id="AsyncvLLMServer", - controller_info=data_system_controller_info, - ) - - self.data_system_client.initialize_storage_manager(manager_type="AsyncSimpleStorageManager", config=self.config) - - async def generate(self, data_meta): - data = await self.data_system_client.async_get_data(data_meta) - logger.info(f"demo get data->generate_sequences {data}") - - data = data["input_ids"] - data += 1 - await asyncio.sleep(3) - - output = TensorDict( - { - "generate_sequences_ids": data, - "non_tensor_data": torch.stack([NonTensorData("test_str") for _ in range(data.size(0))]), - "nested_tensor": torch.nested.as_nested_tensor( - [torch.randn(1, 2) for _ in range(data.size(0))], layout=torch.jagged - ), - }, - batch_size=data.size(0), - ) - - await self.data_system_client.async_put(data=output, metadata=data_meta) - logger.info("demo Async Server put data to storages done") - - return data_meta - - -@ray.remote(num_cpus=1) -class AsyncRolloutWorker: - def __init__( - self, - config, - data_system_controller_info, - ): - self.async_vllm_server = AsyncvLLMServer.remote( - config, - data_system_controller_info, - ) - - async def generate_sequences(self, data_meta_chunk): - tasks = [] - for i in range(data_meta_chunk.size): - # asyncio.create_task cannot directly call Ray Actor methods, - # otherwise an error will be reported:a coroutine was expected, got ObjectRef(xxx) - tasks.append(asyncio.create_task(self.generate(data_meta_chunk[i]))) - data_metas = await asyncio.gather(*tasks) - return BatchMeta.concat(data_metas) - - async def generate(self, data_meta): - data_meta_new = await self.async_vllm_server.generate.remote(data_meta) - return data_meta_new - - -class RolloutManager: - def __init__(self, config, data_system_storage_unit_infos, data_system_controller_info): - self.config = config - - self.data_system_client = AsyncTransferQueueClient( - client_id="RolloutManager", - controller_info=data_system_controller_info, - ) - - self.data_system_client.initialize_storage_manager(manager_type="AsyncSimpleStorageManager", config=self.config) - - self.async_rollout_workers = [] - num_workers = self.config.rollout_agent_num_workers - for i in range(num_workers): - self.async_rollout_workers.append(AsyncRolloutWorker.remote(config, data_system_controller_info)) - - def generate_sequences(self, data_meta): - data_meta_chunkes = data_meta.chunk(len(self.async_rollout_workers)) - data_metas = ray.get( - [ - worker.generate_sequences.remote(data_meta_chunk) - for worker, data_meta_chunk in zip(self.async_rollout_workers, data_meta_chunkes, strict=True) - ] - ) - batch_meta = BatchMeta.concat(data_metas) - logger.info(f"batch_meta: {batch_meta}") - - return batch_meta - - -class Trainer: - def __init__(self, config): - self.config = config - self.data_system_client = self._initialize_data_system() - self.actor_rollout_wg = ActorRolloutRefWorker() - self.async_rollout_manager = RolloutManager( - self.config, - self.data_system_storage_unit_infos, - self.data_system_controller_info, - ) - - def _initialize_data_system(self): - # 1. Initialize TransferQueueStorage - total_storage_size = self.config.global_batch_size * self.config.num_global_batch * self.config.num_n_samples - self.data_system_storage_units = {} - storage_placement_group = get_placement_group(self.config.num_data_storage_units, num_cpus_per_actor=1) - for storage_unit_rank in range(self.config.num_data_storage_units): - storage_node = SimpleStorageUnit.options( - placement_group=storage_placement_group, placement_group_bundle_index=storage_unit_rank - ).remote(storage_unit_size=math.ceil(total_storage_size / self.config.num_data_storage_units)) - self.data_system_storage_units[storage_unit_rank] = storage_node - logger.info(f"SimpleStorageUnit #{storage_unit_rank} has been created.") - - # 2. Initialize TransferQueueController (single controller only) - - # Sampler usage instructions: - # For GRPO grouped sampling, you can initialize the controller with GRPOGroupNSampler: - # Option 1: Pass sampler class (will be instantiated automatically) - # self.data_system_controller = TransferQueueController.remote(sampler=GRPOGroupNSampler) - - # Option 2: Pass sampler instance (if you need custom configuration) - # grpo_sampler = GRPOGroupNSampler() - # self.data_system_controller = TransferQueueController.remote(sampler=grpo_sampler) - - # Then use sampling_config in get_meta calls: - self.data_system_controller = TransferQueueController.remote() - logger.info("TransferQueueController has been created.") - - # 3. Prepare necessary information - self.data_system_controller_info = process_zmq_server_info(self.data_system_controller) - self.data_system_storage_unit_infos = process_zmq_server_info(self.data_system_storage_units) - - tq_config = OmegaConf.create({}, flags={"allow_objects": True}) # Note: Need to generate a new DictConfig - # with allow_objects=True to maintain ZMQServerInfo instance. Otherwise it will be flattened to dict - tq_config.controller_info = self.data_system_controller_info - tq_config.storage_unit_infos = self.data_system_storage_unit_infos - self.config = OmegaConf.merge(tq_config, self.config) - - # 4. Create client - self.data_system_client = AsyncTransferQueueClient( - client_id="Trainer", - controller_info=self.data_system_controller_info, - ) - - self.data_system_client.initialize_storage_manager(manager_type="AsyncSimpleStorageManager", config=self.config) - # Note: The client contains ZMQ objects. Currently, we cannot transmit the same client instance - # to multiple places, as this will cause serialization errors in Ray. - # Workaround: If you need to use a client in multiple Ray actors or processes, create a separate - # AsyncTransferQueueClient instance for each actor/process instead of sharing or transmitting the same instance. - return self.data_system_client - - def fit(self): - for epoch in range(1): - train_dataloader = 1 - for step in range(train_dataloader): - input_ids = ( - torch.tensor([[1, 2], [3, 4], [5, 6], [7, 8], [10, 11], [100, 111], [200, 222], [300, 333]]) - ) * (step + 1) - input_ids_repeated = torch.repeat_interleave(input_ids, self.config.num_n_samples, dim=0) - prompt_batch = TensorDict( - {"input_ids": input_ids_repeated, "attention_mask": input_ids_repeated}, - batch_size=input_ids_repeated.size(0), - ) - - asyncio.run(self.data_system_client.async_put(data=prompt_batch, partition_id=f"train_{step}")) - - logger.info("demo put prompts ok! ") - time.sleep(5) - - batch_meta = asyncio.run( - self.data_system_client.async_get_meta( - data_fields=["input_ids", "attention_mask"], - batch_size=self.config.global_batch_size * self.config.num_n_samples, - partition_id=f"train_{step}", - task_name="generate_sequences", - ) - ) - logger.info(f"demo get meta {batch_meta}") - - # Simulate calling the generate sequences task of the worker group - if not self.config.async_rollout_mode: - batch_meta = self.actor_rollout_wg.actor_rollout_wg_generate_sequences( - batch_meta, self.data_system_client - ) - else: - batch_meta = self.async_rollout_manager.generate_sequences(batch_meta) - log_prob_meta = asyncio.run( - self.data_system_client.async_get_meta( - data_fields=["input_ids", "attention_mask", "generate_sequences_ids"], - batch_size=self.config.global_batch_size * self.config.num_n_samples, - partition_id=f"train_{step}", - task_name="compute_old_log_prob", - ) - ) - logger.info(f"demo get log prob meta: {log_prob_meta}") - - # Simulate calling the compute old log prob task of the worker group - old_log_prob_meta = self.actor_rollout_wg.actor_rollout_wg_compute_old_log_prob( - log_prob_meta, self.data_system_client - ) - - batch_meta = batch_meta.union(old_log_prob_meta) - - # Client notifies controller to clear data status, controller returns metadata; - # Client then notifies the storage plane to clear based on metadata - asyncio.run(self.data_system_client.async_clear_partition(partition_id=f"train_{step}")) - logger.info("clear ok! ") - logger.info("demo done!") - - # Cleanup resources - self.data_system_client.close() - return batch_meta - - -if __name__ == "__main__": - # NOTE: you may choose to set async_rollout_mode=True to test the async rollout mode that mimics - # AgentLoopManager in verl - config_str = """ - global_batch_size: 8 - num_global_batch: 1 - num_data_storage_units: 2 - async_rollout_mode: True - rollout_agent_num_workers: 2 - num_n_samples: 2 - - """ - dict_conf = OmegaConf.create(config_str) - - trainer = Trainer(dict_conf) - trainer.fit() - - ray.shutdown() diff --git a/recipe/simple_use_case/single_controller_demo.py b/recipe/simple_use_case/single_controller_demo.py new file mode 100644 index 00000000..f23f1ebe --- /dev/null +++ b/recipe/simple_use_case/single_controller_demo.py @@ -0,0 +1,318 @@ +# Copyright 2025 Huawei Technologies Co., Ltd. All Rights Reserved. +# Copyright 2025 The TransferQueue Team +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import asyncio +import logging +import os +import random +import sys +import time +import uuid +from importlib import resources +from pathlib import Path + +import ray +import torch +from omegaconf import OmegaConf +from tensordict import NonTensorData, TensorDict + +import transfer_queue as tq +from transfer_queue import KVBatchMeta + +parent_dir = Path(__file__).resolve().parent.parent.parent +sys.path.append(str(parent_dir)) + +logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s") +logger = logging.getLogger(__name__) + +os.environ["RAY_DEDUP_LOGS"] = "0" +os.environ["RAY_DEBUG"] = "1" +ray.init() + + +def compute_log_prob(data1, _data2): + time.sleep(3) + return data1 + + +def compute_loss(data1, _data2): + time.sleep(3) + return data1 + + +def generate_sequences(data): + time.sleep(3) + return data + + +class TrainingWorker: + def __init__(self, role): + self.role = role + + def train_mini_batch(self, kv_meta: KVBatchMeta) -> KVBatchMeta: + """Simulate multi-mini-batch training loop""" + + assert self.role == "actor" + + # 1. Pull data from storage + data = tq.kv_batch_get(keys=kv_meta.keys, partition_id=kv_meta.partition_id, fields=kv_meta.fields) + logger.info(f"train_mini_batch: got data {data}") + + # 2. Compute loss + output = compute_loss(data["old_log_prob"], data["ref_log_prob"]) + output = TensorDict({"loss": output}, batch_size=output.size(0)) + kv_meta.fields.append("loss") + + # 3. Write back + tq.kv_batch_put(keys=kv_meta.keys, partition_id=kv_meta.partition_id, fields=output) + logger.info("train_mini_batch: put data done") + + return kv_meta + + def infer_batch(self, kv_meta: KVBatchMeta) -> KVBatchMeta: + """Simulate forward-only inference""" + # 1. Pull data from storage + data = tq.kv_batch_get(keys=kv_meta.keys, partition_id=kv_meta.partition_id, fields=kv_meta.fields) + logger.info(f"compute_log_prob: got data {data}") + + # 2. Model forward + output = compute_log_prob(data["input_ids"], data["generate_sequences_ids"]) + if self.role == "actor": + output = TensorDict({"old_log_prob": output}, batch_size=output.size(0)) + kv_meta.fields.append("old_log_prob") + elif self.role == "ref": + output = TensorDict({"ref_log_prob": output}, batch_size=output.size(0)) + kv_meta.fields.append("ref_log_prob") + else: + raise ValueError(f"Role {self.role} not supported.") + + # 3. Write back + tq.kv_batch_put(keys=kv_meta.keys, partition_id=kv_meta.partition_id, fields=output) + logger.info("infer_batch: put data done") + + return kv_meta + + +class ActorRolloutRefWorker: + def __init__(self): + self.actor = TrainingWorker(role="actor") + self.ref = TrainingWorker(role="ref") + + def compute_ref_log_prob(self, kv_meta: KVBatchMeta) -> KVBatchMeta: + output = self.ref.infer_batch(kv_meta) + return output + + def compute_log_prob(self, kv_meta: KVBatchMeta) -> KVBatchMeta: + output = self.actor.infer_batch(kv_meta) + return output + + def update_actor(self, kv_meta: KVBatchMeta) -> KVBatchMeta: + output = self.actor.train_mini_batch(kv_meta) + return output + + async def update_weights(self, global_steps: int = None): + # Simulate weight sync from actor to rollout + logger.info(f"update_weights: syncing weights at step {global_steps}") + await asyncio.sleep(1) + + +@ray.remote +class AsyncvLLMServer: + def __init__(self, config): + tq.init(config) + + async def generate(self, kv_meta: KVBatchMeta) -> KVBatchMeta: + data = tq.kv_batch_get(keys=kv_meta.keys, partition_id=kv_meta.partition_id, fields=kv_meta.fields) + logger.info(f"demo get data -> generate_sequences {data}") + + data = data["input_ids"] + data += 1 + await asyncio.sleep(3) + + output = TensorDict( + { + "generate_sequences_ids": data, + "non_tensor_data": torch.stack([NonTensorData("test_str") for _ in range(data.size(0))]), + "nested_tensor": torch.nested.as_nested_tensor( + [torch.randn(1, 2) for _ in range(data.size(0))], layout=torch.jagged + ), + }, + batch_size=data.size(0), + ) + kv_meta.fields.extend(["generate_sequences_ids", "non_tensor_data", "nested_tensor"]) + + tq.kv_batch_put(keys=kv_meta.keys, partition_id=kv_meta.partition_id, fields=output) + logger.info("demo Async Server put data to storages done") + + return kv_meta + + +@ray.remote(num_cpus=1) +class AgentLoopWorker: + def __init__(self, config): + self.async_vllm_server = AsyncvLLMServer.remote(config) + + async def generate_sequences(self, kv_meta_chunk): + if isinstance(kv_meta_chunk, list): + tasks = [] + for item in kv_meta_chunk: + # asyncio.create_task cannot directly call Ray Actor methods, + # otherwise an error will be reported:a coroutine was expected, got ObjectRef(xxx) + tasks.append(asyncio.create_task(self.generate(item))) + kv_metas = await asyncio.gather(*tasks) + return KVBatchMeta.concat(kv_metas) + + elif isinstance(kv_meta_chunk, KVBatchMeta): + kv_meta = await self.generate(kv_meta_chunk) + return kv_meta + + else: + raise TypeError(f"Unsupported type for kv_meta_chunk: {type(kv_meta_chunk)}") + + async def generate(self, kv_meta): + kv_meta_new = await self.async_vllm_server.generate.remote(kv_meta) + return kv_meta_new + + +class AgentLoopManager: + def __init__(self, config): + self.config = config + tq.init(config) + + self.async_rollout_workers = [] + num_workers = self.config.rollout_agent_num_workers + + for _ in range(num_workers): + self.async_rollout_workers.append(AgentLoopWorker.remote(config)) + + def generate_sequences(self, kv_meta): + kv_meta_chunks = kv_meta.chunk(len(self.async_rollout_workers)) + kv_metas = ray.get( + [ + worker.generate_sequences.remote(kv_meta_chunk) + for worker, kv_meta_chunk in zip(self.async_rollout_workers, kv_meta_chunks, strict=True) + ] + ) + kv_meta = KVBatchMeta.concat(kv_metas) + logger.info(f"KVBatchMeta: {kv_meta}") + + return kv_meta + + +class Trainer: + def __init__(self, config): + self.config = config + tq.init(config) + self.tq_client = tq.get_client() + self.actor_rollout_wg = ActorRolloutRefWorker() + self.async_rollout_manager = AgentLoopManager(self.config) + + def fit(self): + for _epoch in range(1): + train_dataloader = 1 + for step in range(train_dataloader): + # ========================= Construct prompt batch data ========================= + input_ids = ( + torch.tensor([[1, 2], [3, 4], [5, 6], [7, 8], [9, 10], [100, 111], [200, 222], [300, 333]]) + ) * (step + 1) + input_ids_repeated = torch.repeat_interleave(input_ids, self.config.num_n_samples, dim=0) + batch_keys = [str(uuid.uuid4()) for _ in range(len(input_ids_repeated))] + prompt_batch = TensorDict( + {"input_ids": input_ids_repeated, "attention_mask": input_ids_repeated}, + batch_size=input_ids_repeated.size(0), + ) + + # ========================= Put prompts to TQ system ========================= + tq.kv_batch_put(keys=batch_keys, partition_id=f"train_{step}", fields=prompt_batch) + logger.info("demo put prompts ok! ") + time.sleep(5) + + # ========================= Sample generate KVBatchMeta ========================= + # TODO: Can be optimized by letting kv_batch_put returns KVBatchMeta directly + sampled_keys = random.sample(batch_keys, self.config.global_batch_size) + gen_meta = KVBatchMeta( + keys=sampled_keys, + tags=[{} for _ in sampled_keys], + partition_id=f"train_{step}", + fields=["input_ids", "attention_mask"], + ) + logger.info(f"demo get gen KVBatchMeta {gen_meta}") + + # ========================= Rollout: generate sequences ========================= + gen_meta = self.async_rollout_manager.generate_sequences(gen_meta) + logger.info(f"demo get after gen KVBatchMeta {gen_meta}") + + # ========================= Compute ref log prob ========================= + gen_meta.fields = ["input_ids", "attention_mask", "generate_sequences_ids"] + ref_log_prob_meta = self.actor_rollout_wg.compute_ref_log_prob(gen_meta) + logger.info(f"demo get ref log prob KVBatchMeta: {ref_log_prob_meta}") + + # ========================= Compute old log prob ========================= + gen_meta.fields = ["input_ids", "attention_mask", "generate_sequences_ids"] + old_log_prob_meta = self.actor_rollout_wg.compute_log_prob(gen_meta) + logger.info(f"demo get old log prob KVBatchMeta: {old_log_prob_meta}") + + # ========================= Compute reward ========================= + # Simulated inline; in real training this calls a reward model worker + gen_meta.fields = ["generate_sequences_ids", "ref_log_prob", "old_log_prob"] + logger.info("demo computing reward (simulated)") + time.sleep(1) + logger.info(f"demo reward KVBatchMeta: {gen_meta}") + + # ========================= Update actor ========================= + gen_meta.fields = [ + "input_ids", + "attention_mask", + "generate_sequences_ids", + "old_log_prob", + "ref_log_prob", + ] + train_meta = self.actor_rollout_wg.update_actor(gen_meta) + logger.info(f"demo get after update actor KVBatchMeta: {train_meta}") + + # ========================= Sync weights to rollout ========================= + asyncio.run(self.actor_rollout_wg.update_weights(global_steps=step)) + logger.info("demo update weights done") + + # ========================= Clear partition in TQ ========================= + self.tq_client.clear_partition(partition_id=f"train_{step}") + logger.info("clear ok! ") + logger.info("demo done!") + + # Cleanup resources + self.tq_client.close() + + +if __name__ == "__main__": + # Demo-level training hyperparameters (not part of TQ config) + demo_conf = OmegaConf.create( + { + "global_batch_size": 8, + "num_global_batch": 1, + "rollout_agent_num_workers": 2, + "num_n_samples": 2, + } + ) + + # Load default TQ config and override as needed + tq_conf = OmegaConf.load(resources.files("transfer_queue") / "config.yaml") + tq_conf = OmegaConf.merge(tq_conf, {"backend": {"SimpleStorage": {"num_data_storage_units": 2}}}) + + dict_conf = OmegaConf.merge(demo_conf, tq_conf) + + trainer = Trainer(dict_conf) + trainer.fit() + + ray.shutdown() diff --git a/recipe/simple_use_case/sync_demo.py b/recipe/simple_use_case/sync_demo.py deleted file mode 100644 index e6513274..00000000 --- a/recipe/simple_use_case/sync_demo.py +++ /dev/null @@ -1,222 +0,0 @@ -# Copyright 2025 Huawei Technologies Co., Ltd. All Rights Reserved. -# Copyright 2025 The TransferQueue Team -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import logging -import math -import os -import sys -import time -from pathlib import Path - -import ray -import torch -from omegaconf import OmegaConf -from tensordict import NonTensorData, TensorDict - -parent_dir = Path(__file__).resolve().parent.parent.parent -sys.path.append(str(parent_dir)) - -from transfer_queue import ( # noqa: E402 - SimpleStorageUnit, - TransferQueueClient, - TransferQueueController, - process_zmq_server_info, -) -from transfer_queue.utils.common import get_placement_group # noqa: E402 - -logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s") -logger = logging.getLogger(__name__) - -os.environ["RAY_DEDUP_LOGS"] = "0" -os.environ["RAY_DEBUG"] = "1" -ray.init() - - -def initialize_data_system(config): - # 1. Initialize TransferQueueStorage - total_storage_size = config.global_batch_size * config.num_global_batch * config.num_n_samples - data_system_storage_units = {} - storage_placement_group = get_placement_group(config.num_data_storage_units, num_cpus_per_actor=1) - for storage_unit_rank in range(config.num_data_storage_units): - storage_node = SimpleStorageUnit.options( - placement_group=storage_placement_group, placement_group_bundle_index=storage_unit_rank - ).remote(storage_unit_size=math.ceil(total_storage_size / config.num_data_storage_units)) - data_system_storage_units[storage_unit_rank] = storage_node - logger.info(f"SimpleStorageUnit #{storage_unit_rank} has been created.") - - # 2. Initialize TransferQueueController (single controller only) - - # Sampler usage instructions: - # For GRPO grouped sampling, you can initialize the controller with GRPOGroupNSampler: - # Option 1: Pass sampler class (will be instantiated automatically) - # data_system_controller = TransferQueueController.remote(sampler=GRPOGroupNSampler) - - # Option 2: Pass sampler instance (if you need custom configuration) - # grpo_sampler = GRPOGroupNSampler() - # data_system_controller = TransferQueueController.remote(sampler=grpo_sampler) - - # Then use sampling_config in get_meta calls: - data_system_controller = TransferQueueController.remote() - logger.info("TransferQueueController has been created.") - - # 3. Prepare necessary information - data_system_controller_info = process_zmq_server_info(data_system_controller) - data_system_storage_unit_infos = process_zmq_server_info(data_system_storage_units) - - tq_config = OmegaConf.create({}, flags={"allow_objects": True}) # Note: Need to generate a new DictConfig - # with allow_objects=True to maintain ZMQServerInfo instance. Otherwise it will be flattened to dict - tq_config.controller_info = data_system_controller_info - tq_config.storage_unit_infos = data_system_storage_unit_infos - config = OmegaConf.merge(tq_config, config) - - # 4. Create client - data_system_client = TransferQueueClient( - client_id="Trainer", - controller_info=data_system_controller_info, - ) - - data_system_client.initialize_storage_manager(manager_type="AsyncSimpleStorageManager", config=config) - - return data_system_controller, data_system_storage_units, data_system_client - - -def generate_sequences(data): - time.sleep(3) - return data - - -def compute_old_log_prob(data1, _data2): - time.sleep(3) - return data1 - - -def actor_rollout_wg_generate_sequences(data_meta, data_system_client): - # 1. Pull real data from the storage plane through client based on data_meta - data = data_system_client.get_data(data_meta) - logger.info(f"demo get data {data}") - - output = generate_sequences(data["input_ids"]) - - output = TensorDict( - { - "generate_sequences_ids": output, - "non_tensor_data": torch.stack([NonTensorData("test_str") for _ in range(output.size(0))]), - "nested_tensor": torch.nested.as_nested_tensor( - [torch.randn(1, 2) for _ in range(output.size(0))], layout=torch.jagged - ), - }, - batch_size=output.size(0), - ) - - # 2. Write results back to the storage plane based on data_meta - data_system_client.put(data=output, metadata=data_meta) - data_meta.add_fields(output) - logger.info("demo put data to storages done") - - return data_meta - - -def actor_rollout_wg_compute_old_log_prob(data_meta, data_system_client): - # 1. Pull real data from the storage plane through client based on data_meta - data = data_system_client.get_data(data_meta) - logger.info(f"demo get data {data}") - - output = compute_old_log_prob(data["input_ids"], data["generate_sequences_ids"]) - - output = TensorDict({"old_log_prob": output}, batch_size=output.size(0)) - - # 2. Write results back to the storage plane based on data_meta - data_system_client.put(data=output, metadata=data_meta) - data_meta.add_fields(output) - logger.info("demo put data to storages done") - - return data_meta - - -# Simulate the fit function of the trainer -def fit(config, data_system_client): - for _epoch in range(1): - train_dataloader = 1 - for step in range(train_dataloader): - input_ids = (torch.tensor([[1, 2], [3, 4], [5, 6], [7, 8], [10, 11], [100, 111]])) * (step + 1) - input_ids_repeated = torch.repeat_interleave(input_ids, config.num_n_samples, dim=0) - prompt_batch = TensorDict( - {"input_ids": input_ids_repeated, "attention_mask": input_ids_repeated}, - batch_size=input_ids_repeated.size(0), - ) - - data_system_client.put(data=prompt_batch, partition_id=f"train_{step}") - logger.info("demo put prompts ok! ") - time.sleep(5) - - batch_meta = data_system_client.get_meta( - data_fields=["input_ids", "attention_mask"], - batch_size=config.global_batch_size, - partition_id=f"train_{step}", - task_name="generate_sequences", - ) - # Set output fields for RL training - in this case, we want to generate sequences from input_ids - logger.info(f"demo get meta {batch_meta}") - - # Simulate calling the generate sequences task of the worker group - batch_meta = actor_rollout_wg_generate_sequences(batch_meta, data_system_client) - log_prob_meta = data_system_client.get_meta( - data_fields=["input_ids", "attention_mask", "generate_sequences_ids"], - batch_size=config.global_batch_size, - partition_id=f"train_{step}", - task_name="compute_old_log_prob", - ) - # Set output fields for RL training - we want to compute log probs for the generated sequences - logger.info(f"demo get log prob meta: {log_prob_meta}") - - # Simulate calling the compute old log prob task of the worker group - old_log_prob_meta = actor_rollout_wg_compute_old_log_prob(log_prob_meta, data_system_client) - - batch_meta = batch_meta.union(old_log_prob_meta) - - # For the master client, notify all controllers to clear data status, master returns metadata; - # Client then notifies the storage plane to clear based on metadata - # Client selects one master controller to get metadata, - # other controllers directly clear without returning metadata - data_system_client.clear_partition(partition_id=f"train_{step}") - logger.info("clear ok! ") - logger.info("demo done!") - - -def main(config): - # Initialize Data System: Launching the Controller and Storage based on Ray - _data_system_controller, _data_system_storage_units, data_system_client = initialize_data_system(config) - import time - - time.sleep(5) - - fit(config, data_system_client) - - # Cleanup resources - data_system_client.close() - - -if __name__ == "__main__": - config_str = """ - global_batch_size: 6 - num_global_batch: 1 - num_data_storage_units: 2 - num_n_samples: 2 - """ - dict_conf = OmegaConf.create(config_str) - - main(dict_conf) - - ray.shutdown()