Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 21 additions & 26 deletions recipe/simple_use_case/single_controller_demo.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,39 +67,36 @@ def train_mini_batch(self, kv_meta: KVBatchMeta) -> KVBatchMeta:
assert self.role == "actor"

# 1. Pull data from storage
data = tq.kv_batch_get(keys=kv_meta.keys, partition_id=kv_meta.partition_id, fields=kv_meta.fields)
data = tq.kv_batch_get_by_meta(meta=kv_meta)
logger.info(f"train_mini_batch: got data {data}")

# 2. Compute loss
output = compute_loss(data["old_log_prob"], data["ref_log_prob"])
output = TensorDict({"loss": output}, batch_size=output.size(0))
kv_meta.fields.append("loss")

# 3. Write back
tq.kv_batch_put(keys=kv_meta.keys, partition_id=kv_meta.partition_id, fields=output)
kv_meta = tq.kv_batch_put(keys=kv_meta.keys, partition_id=kv_meta.partition_id, fields=output)
logger.info("train_mini_batch: put data done")

return kv_meta

def infer_batch(self, kv_meta: KVBatchMeta) -> KVBatchMeta:
"""Simulate forward-only inference.

Reads the batch described by ``kv_meta``, calls ``compute_log_prob`` on its
``input_ids`` and ``generate_sequences_ids`` entries, and writes the result
back under ``old_log_prob`` (role "actor") or ``ref_log_prob`` (role "ref").

Returns:
The ``kv_meta`` value held when the method finishes.

Raises:
ValueError: If ``self.role`` is neither "actor" nor "ref".
"""
# NOTE(review): this span looks like a rendered diff without +/- markers and
# without indentation. The paired `data = ...` reads and the paired
# `kv_batch_put` calls below are presumably the removed and added versions
# of one statement each, and the `kv_meta.fields.append(...)` lines may be
# removed lines too -- confirm against the repository before treating every
# line as live code.
# 1. Pull data from storage
data = tq.kv_batch_get(keys=kv_meta.keys, partition_id=kv_meta.partition_id, fields=kv_meta.fields)
data = tq.kv_batch_get_by_meta(meta=kv_meta)
logger.info(f"compute_log_prob: got data {data}")

# 2. Model forward
output = compute_log_prob(data["input_ids"], data["generate_sequences_ids"])
# The output field name depends on the worker role; unknown roles are rejected.
if self.role == "actor":
output = TensorDict({"old_log_prob": output}, batch_size=output.size(0))
kv_meta.fields.append("old_log_prob")
elif self.role == "ref":
output = TensorDict({"ref_log_prob": output}, batch_size=output.size(0))
kv_meta.fields.append("ref_log_prob")
else:
raise ValueError(f"Role {self.role} not supported.")

# 3. Write back
tq.kv_batch_put(keys=kv_meta.keys, partition_id=kv_meta.partition_id, fields=output)
# This variant rebinds kv_meta to whatever kv_batch_put returns, and that is
# the value handed back to the caller.
kv_meta = tq.kv_batch_put(keys=kv_meta.keys, partition_id=kv_meta.partition_id, fields=output)
logger.info("infer_batch: put data done")

return kv_meta
Expand Down Expand Up @@ -134,7 +131,7 @@ def __init__(self, config):
tq.init(config)

async def generate(self, kv_meta: KVBatchMeta) -> KVBatchMeta:
data = tq.kv_batch_get(keys=kv_meta.keys, partition_id=kv_meta.partition_id, fields=kv_meta.fields)
data = tq.kv_batch_get_by_meta(meta=kv_meta)
logger.info(f"demo get data -> generate_sequences {data}")

data = data["input_ids"]
Expand All @@ -151,9 +148,8 @@ async def generate(self, kv_meta: KVBatchMeta) -> KVBatchMeta:
},
batch_size=data.size(0),
)
kv_meta.fields.extend(["generate_sequences_ids", "non_tensor_data", "nested_tensor"])

tq.kv_batch_put(keys=kv_meta.keys, partition_id=kv_meta.partition_id, fields=output)
kv_meta = tq.kv_batch_put(keys=kv_meta.keys, partition_id=kv_meta.partition_id, fields=output)
logger.info("demo Async Server put data to storages done")

return kv_meta
Expand Down Expand Up @@ -240,47 +236,46 @@ def fit(self):
time.sleep(5)

# ========================= Sample generate KVBatchMeta =========================
# TODO: Can be optimized by letting kv_batch_put return KVBatchMeta directly
sampled_keys = random.sample(batch_keys, self.config.global_batch_size)
gen_meta = KVBatchMeta(
meta = KVBatchMeta(
keys=sampled_keys,
tags=[{} for _ in sampled_keys],
partition_id=f"train_{step}",
fields=["input_ids", "attention_mask"],
)
logger.info(f"demo get gen KVBatchMeta {gen_meta}")
logger.info(f"demo get KVBatchMeta {meta}")

# ========================= Rollout: generate sequences =========================
gen_meta = self.async_rollout_manager.generate_sequences(gen_meta)
logger.info(f"demo get after gen KVBatchMeta {gen_meta}")
meta = self.async_rollout_manager.generate_sequences(meta)
logger.info(f"demo get after gen KVBatchMeta {meta}")

# ========================= Compute ref log prob =========================
gen_meta.fields = ["input_ids", "attention_mask", "generate_sequences_ids"]
ref_log_prob_meta = self.actor_rollout_wg.compute_ref_log_prob(gen_meta)
logger.info(f"demo get ref log prob KVBatchMeta: {ref_log_prob_meta}")
meta.fields = ["input_ids", "attention_mask", "generate_sequences_ids"]
meta = self.actor_rollout_wg.compute_ref_log_prob(meta)
logger.info(f"demo get ref log prob KVBatchMeta: {meta}")

# ========================= Compute old log prob =========================
gen_meta.fields = ["input_ids", "attention_mask", "generate_sequences_ids"]
old_log_prob_meta = self.actor_rollout_wg.compute_log_prob(gen_meta)
logger.info(f"demo get old log prob KVBatchMeta: {old_log_prob_meta}")
meta.fields = ["input_ids", "attention_mask", "generate_sequences_ids"]
meta = self.actor_rollout_wg.compute_log_prob(meta)
logger.info(f"demo get old log prob KVBatchMeta: {meta}")

# ========================= Compute reward =========================
# Simulated inline; in real training this calls a reward model worker
gen_meta.fields = ["generate_sequences_ids", "ref_log_prob", "old_log_prob"]
meta.fields = ["generate_sequences_ids", "ref_log_prob", "old_log_prob"]
logger.info("demo computing reward (simulated)")
time.sleep(1)
logger.info(f"demo reward KVBatchMeta: {gen_meta}")
logger.info(f"demo reward KVBatchMeta: {meta}")

# ========================= Update actor =========================
gen_meta.fields = [
meta.fields = [
"input_ids",
"attention_mask",
"generate_sequences_ids",
"old_log_prob",
"ref_log_prob",
]
train_meta = self.actor_rollout_wg.update_actor(gen_meta)
logger.info(f"demo get after update actor KVBatchMeta: {train_meta}")
meta = self.actor_rollout_wg.update_actor(meta)
logger.info(f"demo get after update actor KVBatchMeta: {meta}")

# ========================= Sync weights to rollout =========================
asyncio.run(self.actor_rollout_wg.update_weights(global_steps=step))
Expand Down
Loading
Loading